mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 19:03:25 +00:00
langchain[patch]: Make more memory code handle community dependency as optional (#21199)
This commit is contained in:
parent
bd5d2c2674
commit
df49404794
@ -26,26 +26,9 @@
|
|||||||
|
|
||||||
AIMessage, BaseMessage, HumanMessage
|
AIMessage, BaseMessage, HumanMessage
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
from langchain_community.chat_message_histories import (
|
from typing import TYPE_CHECKING, Any
|
||||||
AstraDBChatMessageHistory,
|
|
||||||
CassandraChatMessageHistory,
|
|
||||||
ChatMessageHistory,
|
|
||||||
CosmosDBChatMessageHistory,
|
|
||||||
DynamoDBChatMessageHistory,
|
|
||||||
ElasticsearchChatMessageHistory,
|
|
||||||
FileChatMessageHistory,
|
|
||||||
MomentoChatMessageHistory,
|
|
||||||
MongoDBChatMessageHistory,
|
|
||||||
PostgresChatMessageHistory,
|
|
||||||
RedisChatMessageHistory,
|
|
||||||
SingleStoreDBChatMessageHistory,
|
|
||||||
SQLChatMessageHistory,
|
|
||||||
StreamlitChatMessageHistory,
|
|
||||||
UpstashRedisChatMessageHistory,
|
|
||||||
XataChatMessageHistory,
|
|
||||||
ZepChatMessageHistory,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from langchain._api import create_importer
|
||||||
from langchain.memory.buffer import (
|
from langchain.memory.buffer import (
|
||||||
ConversationBufferMemory,
|
ConversationBufferMemory,
|
||||||
ConversationStringBufferMemory,
|
ConversationStringBufferMemory,
|
||||||
@ -59,15 +42,72 @@ from langchain.memory.entity import (
|
|||||||
SQLiteEntityStore,
|
SQLiteEntityStore,
|
||||||
UpstashRedisEntityStore,
|
UpstashRedisEntityStore,
|
||||||
)
|
)
|
||||||
from langchain.memory.kg import ConversationKGMemory
|
|
||||||
from langchain.memory.motorhead_memory import MotorheadMemory
|
|
||||||
from langchain.memory.readonly import ReadOnlySharedMemory
|
from langchain.memory.readonly import ReadOnlySharedMemory
|
||||||
from langchain.memory.simple import SimpleMemory
|
from langchain.memory.simple import SimpleMemory
|
||||||
from langchain.memory.summary import ConversationSummaryMemory
|
from langchain.memory.summary import ConversationSummaryMemory
|
||||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||||
from langchain.memory.zep_memory import ZepMemory
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_community.chat_message_histories import (
|
||||||
|
AstraDBChatMessageHistory,
|
||||||
|
CassandraChatMessageHistory,
|
||||||
|
ChatMessageHistory,
|
||||||
|
CosmosDBChatMessageHistory,
|
||||||
|
DynamoDBChatMessageHistory,
|
||||||
|
ElasticsearchChatMessageHistory,
|
||||||
|
FileChatMessageHistory,
|
||||||
|
MomentoChatMessageHistory,
|
||||||
|
MongoDBChatMessageHistory,
|
||||||
|
PostgresChatMessageHistory,
|
||||||
|
RedisChatMessageHistory,
|
||||||
|
SingleStoreDBChatMessageHistory,
|
||||||
|
SQLChatMessageHistory,
|
||||||
|
StreamlitChatMessageHistory,
|
||||||
|
UpstashRedisChatMessageHistory,
|
||||||
|
XataChatMessageHistory,
|
||||||
|
ZepChatMessageHistory,
|
||||||
|
)
|
||||||
|
from langchain_community.memory.kg import ConversationKGMemory
|
||||||
|
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||||
|
from langchain_community.memory.zep_memory import ZepMemory
|
||||||
|
|
||||||
|
|
||||||
|
# Create a way to dynamically look up deprecated imports.
|
||||||
|
# Used to consolidate logic for raising deprecation warnings and
|
||||||
|
# handling optional imports.
|
||||||
|
DEPRECATED_LOOKUP = {
|
||||||
|
"MotorheadMemory": "langchain_community.memory.motorhead_memory",
|
||||||
|
"ConversationKGMemory": "langchain_community.memory.kg",
|
||||||
|
"ZepMemory": "langchain_community.memory.zep_memory",
|
||||||
|
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"ChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"FileChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"XataChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
"""Look up attributes dynamically."""
|
||||||
|
return _import_attribute(name)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AstraDBChatMessageHistory",
|
"AstraDBChatMessageHistory",
|
||||||
|
@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
from langchain_community.utilities.redis import get_client
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
@ -181,6 +180,14 @@ class RedisEntityStore(BaseEntityStore):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_community.utilities.redis import get_client
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import langchain_community.utilities.redis.get_client. "
|
||||||
|
"Please install it with `pip install langchain-community`."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||||
except redis.exceptions.ConnectionError as error:
|
except redis.exceptions.ConnectionError as error:
|
||||||
|
@ -1,133 +1,23 @@
|
|||||||
from typing import Any, Dict, List, Type, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain_community.graphs import NetworkxEntityGraph
|
from langchain._api import create_importer
|
||||||
from langchain_community.graphs.networkx_graph import (
|
|
||||||
KnowledgeTriple,
|
|
||||||
get_entities,
|
|
||||||
parse_triples,
|
|
||||||
)
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
|
||||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
|
||||||
from langchain_core.pydantic_v1 import Field
|
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
if TYPE_CHECKING:
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain_community.memory.kg import ConversationKGMemory
|
||||||
from langchain.memory.prompt import (
|
|
||||||
ENTITY_EXTRACTION_PROMPT,
|
# Create a way to dynamically look up deprecated imports.
|
||||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
# Used to consolidate logic for raising deprecation warnings and
|
||||||
)
|
# handling optional imports.
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}
|
||||||
|
|
||||||
|
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||||
|
|
||||||
|
|
||||||
class ConversationKGMemory(BaseChatMemory):
|
def __getattr__(name: str) -> Any:
|
||||||
"""Knowledge graph conversation memory.
|
"""Look up attributes dynamically."""
|
||||||
|
return _import_attribute(name)
|
||||||
|
|
||||||
Integrates with external knowledge graph to store and retrieve
|
|
||||||
information about knowledge triples in the conversation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
k: int = 2
|
__all__ = [
|
||||||
human_prefix: str = "Human"
|
"ConversationKGMemory",
|
||||||
ai_prefix: str = "AI"
|
]
|
||||||
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
|
|
||||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
|
||||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
|
||||||
llm: BaseLanguageModel
|
|
||||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
|
||||||
"""Number of previous utterances to include in the context."""
|
|
||||||
memory_key: str = "history" #: :meta private:
|
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Return history buffer."""
|
|
||||||
entities = self._get_current_entities(inputs)
|
|
||||||
|
|
||||||
summary_strings = []
|
|
||||||
for entity in entities:
|
|
||||||
knowledge = self.kg.get_entity_knowledge(entity)
|
|
||||||
if knowledge:
|
|
||||||
summary = f"On {entity}: {'. '.join(knowledge)}."
|
|
||||||
summary_strings.append(summary)
|
|
||||||
context: Union[str, List]
|
|
||||||
if not summary_strings:
|
|
||||||
context = [] if self.return_messages else ""
|
|
||||||
elif self.return_messages:
|
|
||||||
context = [
|
|
||||||
self.summary_message_cls(content=text) for text in summary_strings
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
context = "\n".join(summary_strings)
|
|
||||||
|
|
||||||
return {self.memory_key: context}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def memory_variables(self) -> List[str]:
|
|
||||||
"""Will always return list of memory variables.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.memory_key]
|
|
||||||
|
|
||||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
|
||||||
"""Get the input key for the prompt."""
|
|
||||||
if self.input_key is None:
|
|
||||||
return get_prompt_input_key(inputs, self.memory_variables)
|
|
||||||
return self.input_key
|
|
||||||
|
|
||||||
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
|
|
||||||
"""Get the output key for the prompt."""
|
|
||||||
if self.output_key is None:
|
|
||||||
if len(outputs) != 1:
|
|
||||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
|
||||||
return list(outputs.keys())[0]
|
|
||||||
return self.output_key
|
|
||||||
|
|
||||||
def get_current_entities(self, input_string: str) -> List[str]:
|
|
||||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
|
||||||
buffer_string = get_buffer_string(
|
|
||||||
self.chat_memory.messages[-self.k * 2 :],
|
|
||||||
human_prefix=self.human_prefix,
|
|
||||||
ai_prefix=self.ai_prefix,
|
|
||||||
)
|
|
||||||
output = chain.predict(
|
|
||||||
history=buffer_string,
|
|
||||||
input=input_string,
|
|
||||||
)
|
|
||||||
return get_entities(output)
|
|
||||||
|
|
||||||
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
|
||||||
"""Get the current entities in the conversation."""
|
|
||||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
|
||||||
return self.get_current_entities(inputs[prompt_input_key])
|
|
||||||
|
|
||||||
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
|
|
||||||
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
|
||||||
buffer_string = get_buffer_string(
|
|
||||||
self.chat_memory.messages[-self.k * 2 :],
|
|
||||||
human_prefix=self.human_prefix,
|
|
||||||
ai_prefix=self.ai_prefix,
|
|
||||||
)
|
|
||||||
output = chain.predict(
|
|
||||||
history=buffer_string,
|
|
||||||
input=input_string,
|
|
||||||
verbose=True,
|
|
||||||
)
|
|
||||||
knowledge = parse_triples(output)
|
|
||||||
return knowledge
|
|
||||||
|
|
||||||
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
|
||||||
"""Get and update knowledge graph from the conversation history."""
|
|
||||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
|
||||||
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
|
|
||||||
for triple in knowledge:
|
|
||||||
self.kg.add_triple(triple)
|
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
|
||||||
"""Save context from this conversation to buffer."""
|
|
||||||
super().save_context(inputs, outputs)
|
|
||||||
self._get_and_update_kg(inputs)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""Clear memory contents."""
|
|
||||||
super().clear()
|
|
||||||
self.kg.clear()
|
|
||||||
|
@ -1,3 +1,23 @@
|
|||||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
__all__ = ["MotorheadMemory"]
|
from langchain._api import create_importer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||||
|
|
||||||
|
# Create a way to dynamically look up deprecated imports.
|
||||||
|
# Used to consolidate logic for raising deprecation warnings and
|
||||||
|
# handling optional imports.
|
||||||
|
DEPRECATED_LOOKUP = {"MotorheadMemory": "langchain_community.memory.motorhead_memory"}
|
||||||
|
|
||||||
|
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
"""Look up attributes dynamically."""
|
||||||
|
return _import_attribute(name)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MotorheadMemory",
|
||||||
|
]
|
||||||
|
@ -1,3 +1,23 @@
|
|||||||
from langchain_community.memory.zep_memory import ZepMemory
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
__all__ = ["ZepMemory"]
|
from langchain._api import create_importer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_community.memory.zep_memory import ZepMemory
|
||||||
|
|
||||||
|
# Create a way to dynamically look up deprecated imports.
|
||||||
|
# Used to consolidate logic for raising deprecation warnings and
|
||||||
|
# handling optional imports.
|
||||||
|
DEPRECATED_LOOKUP = {"ZepMemory": "langchain_community.memory.zep_memory"}
|
||||||
|
|
||||||
|
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
|
||||||
|
"""Look up attributes dynamically."""
|
||||||
|
return _import_attribute(name)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ZepMemory",
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user