mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
5 Commits
mdrxy/vers
...
wfh/memory
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60313c3f62 | ||
|
|
b9bd33ff30 | ||
|
|
574d752c7e | ||
|
|
914489a0a1 | ||
|
|
506da41b97 |
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
|
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
|
||||||
from langchain.vectorstores.base import VectorStoreRetriever
|
from langchain.vectorstores.base import VectorStoreRetriever
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@@ -19,10 +20,13 @@ class AutoGPTMemory(BaseChatMemory):
|
|||||||
return get_prompt_input_key(inputs, self.memory_variables)
|
return get_prompt_input_key(inputs, self.memory_variables)
|
||||||
return self.input_key
|
return self.input_key
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
input_key = self._get_prompt_input_key(inputs)
|
input_key = self._get_prompt_input_key(inputs)
|
||||||
query = inputs[input_key]
|
query = inputs[input_key]
|
||||||
docs = self.retriever.get_relevant_documents(query)
|
docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
|
||||||
return {
|
return {
|
||||||
"chat_history": self.chat_memory.messages[-10:],
|
"chat_history": self.chat_memory.messages[-10:],
|
||||||
"relevant_context": docs,
|
"relevant_context": docs,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import re
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||||
@@ -222,14 +223,21 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def fetch_memories(
|
def fetch_memories(
|
||||||
self, observation: str, now: Optional[datetime] = None
|
self,
|
||||||
|
observation: str,
|
||||||
|
now: Optional[datetime] = None,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Fetch related memories."""
|
"""Fetch related memories."""
|
||||||
if now is not None:
|
if now is not None:
|
||||||
with mock_now(now):
|
with mock_now(now):
|
||||||
return self.memory_retriever.get_relevant_documents(observation)
|
return self.memory_retriever.get_relevant_documents(
|
||||||
|
observation, callbacks=callbacks
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.memory_retriever.get_relevant_documents(observation)
|
return self.memory_retriever.get_relevant_documents(
|
||||||
|
observation, callbacks=callbacks
|
||||||
|
)
|
||||||
|
|
||||||
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
|
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
|
||||||
content = []
|
content = []
|
||||||
@@ -260,7 +268,9 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
"""Input keys this memory class will load dynamically."""
|
"""Input keys this memory class will load dynamically."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
queries = inputs.get(self.queries_key)
|
queries = inputs.get(self.queries_key)
|
||||||
now = inputs.get(self.now_key)
|
now = inputs.get(self.now_key)
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
|
|||||||
|
|
||||||
@inputs.setter
|
@inputs.setter
|
||||||
def inputs(self, inputs: Any) -> None:
|
def inputs(self, inputs: Any) -> None:
|
||||||
self._inputs = self.agent_executor.prep_inputs(inputs)
|
self._inputs = self.agent_executor.prep_inputs(inputs, callbacks=None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def callbacks(self) -> Callbacks:
|
def callbacks(self) -> Callbacks:
|
||||||
|
|||||||
@@ -217,7 +217,6 @@ class Chain(Serializable, ABC):
|
|||||||
A dict of named outputs. Should contain all outputs specified in
|
A dict of named outputs. Should contain all outputs specified in
|
||||||
`Chain.output_keys`.
|
`Chain.output_keys`.
|
||||||
"""
|
"""
|
||||||
inputs = self.prep_inputs(inputs)
|
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
callbacks,
|
callbacks,
|
||||||
self.callbacks,
|
self.callbacks,
|
||||||
@@ -232,6 +231,7 @@ class Chain(Serializable, ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
inputs,
|
inputs,
|
||||||
)
|
)
|
||||||
|
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
|
||||||
try:
|
try:
|
||||||
outputs = (
|
outputs = (
|
||||||
self._call(inputs, run_manager=run_manager)
|
self._call(inputs, run_manager=run_manager)
|
||||||
@@ -284,7 +284,6 @@ class Chain(Serializable, ABC):
|
|||||||
A dict of named outputs. Should contain all outputs specified in
|
A dict of named outputs. Should contain all outputs specified in
|
||||||
`Chain.output_keys`.
|
`Chain.output_keys`.
|
||||||
"""
|
"""
|
||||||
inputs = self.prep_inputs(inputs)
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
callbacks,
|
callbacks,
|
||||||
self.callbacks,
|
self.callbacks,
|
||||||
@@ -299,6 +298,7 @@ class Chain(Serializable, ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
inputs,
|
inputs,
|
||||||
)
|
)
|
||||||
|
inputs = self.prep_inputs(inputs, callbacks=run_manager.get_child())
|
||||||
try:
|
try:
|
||||||
outputs = (
|
outputs = (
|
||||||
await self._acall(inputs, run_manager=run_manager)
|
await self._acall(inputs, run_manager=run_manager)
|
||||||
@@ -342,7 +342,9 @@ class Chain(Serializable, ABC):
|
|||||||
else:
|
else:
|
||||||
return {**inputs, **outputs}
|
return {**inputs, **outputs}
|
||||||
|
|
||||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
def prep_inputs(
|
||||||
|
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Validate and prepare chain inputs, including adding inputs from memory.
|
"""Validate and prepare chain inputs, including adding inputs from memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -369,7 +371,9 @@ class Chain(Serializable, ABC):
|
|||||||
)
|
)
|
||||||
inputs = {list(_input_keys)[0]: inputs}
|
inputs = {list(_input_keys)[0]: inputs}
|
||||||
if self.memory is not None:
|
if self.memory is not None:
|
||||||
external_context = self.memory.load_memory_variables(inputs)
|
external_context = self.memory.load_memory_variables(
|
||||||
|
inputs, callbacks=callbacks
|
||||||
|
)
|
||||||
inputs = dict(inputs, **external_context)
|
inputs = dict(inputs, **external_context)
|
||||||
self._validate_inputs(inputs)
|
self._validate_inputs(inputs)
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -262,11 +262,13 @@ The following is the expected answer. Use this to measure correctness:
|
|||||||
return ["score", "reasoning"]
|
return ["score", "reasoning"]
|
||||||
return ["score"]
|
return ["score"]
|
||||||
|
|
||||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
def prep_inputs(
|
||||||
|
self, inputs: Union[Dict[str, Any], Any], *, callbacks: Callbacks
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Validate and prep inputs."""
|
"""Validate and prep inputs."""
|
||||||
if "reference" not in inputs:
|
if "reference" not in inputs:
|
||||||
inputs["reference"] = self._format_reference(inputs.get("reference"))
|
inputs["reference"] = self._format_reference(inputs.get("reference"))
|
||||||
return super().prep_inputs(inputs)
|
return super().prep_inputs(inputs, callbacks=callbacks)
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
|
from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
|
||||||
from langchain.vectorstores.base import VectorStoreRetriever
|
from langchain.vectorstores.base import VectorStoreRetriever
|
||||||
|
|
||||||
@@ -20,7 +21,9 @@ class AutoGPTMemory(BaseChatMemory):
|
|||||||
return get_prompt_input_key(inputs, self.memory_variables)
|
return get_prompt_input_key(inputs, self.memory_variables)
|
||||||
return self.input_key
|
return self.input_key
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
input_key = self._get_prompt_input_key(inputs)
|
input_key = self._get_prompt_input_key(inputs)
|
||||||
query = inputs[input_key]
|
query = inputs[input_key]
|
||||||
docs = self.retriever.get_relevant_documents(query)
|
docs = self.retriever.get_relevant_documents(query)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain import LLMChain
|
from langchain import LLMChain
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||||
from langchain.schema import BaseMemory, Document
|
from langchain.schema import BaseMemory, Document
|
||||||
@@ -260,7 +261,9 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
"""Input keys this memory class will load dynamically."""
|
"""Input keys this memory class will load dynamically."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
queries = inputs.get(self.queries_key)
|
queries = inputs.get(self.queries_key)
|
||||||
now = inputs.get(self.now_key)
|
now = inputs.get(self.now_key)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
from langchain.memory.utils import get_prompt_input_key
|
||||||
from langchain.schema.messages import get_buffer_string
|
from langchain.schema.messages import get_buffer_string
|
||||||
@@ -34,7 +35,9 @@ class ConversationBufferMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
|
||||||
@@ -66,7 +69,9 @@ class ConversationStringBufferMemory(BaseMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
@@ -26,7 +27,9 @@ class ConversationBufferWindowMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
|
|
||||||
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
|
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Set
|
|||||||
|
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.schema import BaseMemory
|
from langchain.schema import BaseMemory
|
||||||
|
|
||||||
@@ -54,13 +55,15 @@ class CombinedMemory(BaseMemory):
|
|||||||
|
|
||||||
return memory_variables
|
return memory_variables
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Load all vars from sub-memories."""
|
"""Load all vars from sub-memories."""
|
||||||
memory_data: Dict[str, Any] = {}
|
memory_data: Dict[str, Any] = {}
|
||||||
|
|
||||||
# Collect vars from all sub-memories
|
# Collect vars from all sub-memories
|
||||||
for memory in self.memories:
|
for memory in self.memories:
|
||||||
data = memory.load_memory_variables(inputs)
|
data = memory.load_memory_variables(inputs, callbacks=callbacks)
|
||||||
memory_data = {
|
memory_data = {
|
||||||
**memory_data,
|
**memory_data,
|
||||||
**data,
|
**data,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.prompt import (
|
from langchain.memory.prompt import (
|
||||||
@@ -285,7 +286,9 @@ class ConversationEntityMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return ["entities", self.chat_history_key]
|
return ["entities", self.chat_history_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Returns chat history and all generated entities with summaries if available,
|
Returns chat history and all generated entities with summaries if available,
|
||||||
and updates or clears the recent entity cache.
|
and updates or clears the recent entity cache.
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.graphs import NetworkxEntityGraph
|
from langchain.graphs import NetworkxEntityGraph
|
||||||
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
|
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
|
||||||
@@ -34,7 +35,9 @@ class ConversationKGMemory(BaseChatMemory):
|
|||||||
"""Number of previous utterances to include in the context."""
|
"""Number of previous utterances to include in the context."""
|
||||||
memory_key: str = "history" #: :meta private:
|
memory_key: str = "history" #: :meta private:
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
entities = self._get_current_entities(inputs)
|
entities = self._get_current_entities(inputs)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.schema.messages import get_buffer_string
|
from langchain.schema.messages import get_buffer_string
|
||||||
|
|
||||||
@@ -64,7 +65,9 @@ class MotorheadMemory(BaseChatMemory):
|
|||||||
if context and context != "NONE":
|
if context and context != "NONE":
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
return {self.memory_key: self.chat_memory.messages}
|
return {self.memory_key: self.chat_memory.messages}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.schema import BaseMemory
|
from langchain.schema import BaseMemory
|
||||||
|
|
||||||
|
|
||||||
@@ -13,9 +14,11 @@ class ReadOnlySharedMemory(BaseMemory):
|
|||||||
"""Return memory variables."""
|
"""Return memory variables."""
|
||||||
return self.memory.memory_variables
|
return self.memory.memory_variables
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
"""Load memory variables from memory."""
|
"""Load memory variables from memory."""
|
||||||
return self.memory.load_memory_variables(inputs)
|
return self.memory.load_memory_variables(inputs, callbacks=callbacks)
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
"""Nothing should be saved or changed"""
|
"""Nothing should be saved or changed"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.schema import BaseMemory
|
from langchain.schema import BaseMemory
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +15,9 @@ class SimpleMemory(BaseMemory):
|
|||||||
def memory_variables(self) -> List[str]:
|
def memory_variables(self) -> List[str]:
|
||||||
return list(self.memories.keys())
|
return list(self.memories.keys())
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
return self.memories
|
return self.memories
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Type
|
|||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||||
@@ -67,7 +68,9 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.summary import SummarizerMixin
|
from langchain.memory.summary import SummarizerMixin
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
@@ -26,7 +27,9 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
buffer = self.buffer
|
buffer = self.buffer
|
||||||
if self.moving_summary_buffer != "":
|
if self.moving_summary_buffer != "":
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
@@ -27,7 +28,9 @@ class ConversationTokenBufferMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
buffer: Any = self.buffer
|
buffer: Any = self.buffer
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.memory.chat_memory import BaseMemory
|
from langchain.memory.chat_memory import BaseMemory
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
from langchain.memory.utils import get_prompt_input_key
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
@@ -40,12 +41,12 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
|||||||
return self.input_key
|
return self.input_key
|
||||||
|
|
||||||
def load_memory_variables(
|
def load_memory_variables(
|
||||||
self, inputs: Dict[str, Any]
|
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||||
) -> Dict[str, Union[List[Document], str]]:
|
) -> Dict[str, Union[List[Document], str]]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
input_key = self._get_prompt_input_key(inputs)
|
input_key = self._get_prompt_input_key(inputs)
|
||||||
query = inputs[input_key]
|
query = inputs[input_key]
|
||||||
docs = self.retriever.get_relevant_documents(query)
|
docs = self.retriever.get_relevant_documents(query, callbacks=callbacks)
|
||||||
result: Union[List[Document], str]
|
result: Union[List[Document], str]
|
||||||
if not self.return_docs:
|
if not self.return_docs:
|
||||||
result = "\n".join([doc.page_content for doc in docs])
|
result = "\n".join([doc.page_content for doc in docs])
|
||||||
|
|||||||
@@ -94,7 +94,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
inputs = self.llm_chain.prep_inputs(
|
||||||
|
{"query": query}, callbacks=run_manager.get_child()
|
||||||
|
)
|
||||||
structured_query = cast(
|
structured_query = cast(
|
||||||
StructuredQuery,
|
StructuredQuery,
|
||||||
self.llm_chain.predict_and_parse(
|
self.llm_chain.predict_and_parse(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import langchain
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
|
||||||
@@ -27,7 +28,7 @@ class BaseMemory(Serializable, ABC):
|
|||||||
def memory_variables(self) -> List[str]:
|
def memory_variables(self) -> List[str]:
|
||||||
return list(self.memories.keys())
|
return list(self.memories.keys())
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Dict[str, Any]:
|
||||||
return self.memories
|
return self.memories
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
@@ -48,7 +49,11 @@ class BaseMemory(Serializable, ABC):
|
|||||||
"""The string keys this memory class will add to chain inputs."""
|
"""The string keys this memory class will add to chain inputs."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
callbacks: "langchain.callbacks.manager.Callbacks" = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.schema import RUN_KEY, BaseMemory
|
from langchain.schema import RUN_KEY, BaseMemory
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
@@ -18,7 +18,7 @@ class FakeMemory(BaseMemory):
|
|||||||
return ["baz"]
|
return ["baz"]
|
||||||
|
|
||||||
def load_memory_variables(
|
def load_memory_variables(
|
||||||
self, inputs: Optional[Dict[str, Any]] = None
|
self, inputs: Optional[Dict[str, Any]] = None, callbacks: Callbacks = None
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Return baz variable."""
|
"""Return baz variable."""
|
||||||
return {"baz": "foo"}
|
return {"baz": "foo"}
|
||||||
|
|||||||
Reference in New Issue
Block a user