langchain[patch]: Make more memory code handle community dependency as optional (#21199)

This commit is contained in:
Eugene Yurtsev 2024-05-02 11:05:26 -04:00 committed by GitHub
parent bd5d2c2674
commit df49404794
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 131 additions and 154 deletions

View File

@ -26,26 +26,9 @@
AIMessage, BaseMessage, HumanMessage
""" # noqa: E501
from langchain_community.chat_message_histories import (
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
DynamoDBChatMessageHistory,
ElasticsearchChatMessageHistory,
FileChatMessageHistory,
MomentoChatMessageHistory,
MongoDBChatMessageHistory,
PostgresChatMessageHistory,
RedisChatMessageHistory,
SingleStoreDBChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
UpstashRedisChatMessageHistory,
XataChatMessageHistory,
ZepChatMessageHistory,
)
from typing import TYPE_CHECKING, Any
from langchain._api import create_importer
from langchain.memory.buffer import (
ConversationBufferMemory,
ConversationStringBufferMemory,
@ -59,15 +42,72 @@ from langchain.memory.entity import (
SQLiteEntityStore,
UpstashRedisEntityStore,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.motorhead_memory import MotorheadMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
from langchain.memory.summary import ConversationSummaryMemory
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.zep_memory import ZepMemory
if TYPE_CHECKING:
from langchain_community.chat_message_histories import (
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
DynamoDBChatMessageHistory,
ElasticsearchChatMessageHistory,
FileChatMessageHistory,
MomentoChatMessageHistory,
MongoDBChatMessageHistory,
PostgresChatMessageHistory,
RedisChatMessageHistory,
SingleStoreDBChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
UpstashRedisChatMessageHistory,
XataChatMessageHistory,
ZepChatMessageHistory,
)
from langchain_community.memory.kg import ConversationKGMemory
from langchain_community.memory.motorhead_memory import MotorheadMemory
from langchain_community.memory.zep_memory import ZepMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"MotorheadMemory": "langchain_community.memory.motorhead_memory",
"ConversationKGMemory": "langchain_community.memory.kg",
"ZepMemory": "langchain_community.memory.zep_memory",
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
"ChatMessageHistory": "langchain_community.chat_message_histories",
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
"FileChatMessageHistory": "langchain_community.chat_message_histories",
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
"XataChatMessageHistory": "langchain_community.chat_message_histories",
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"AstraDBChatMessageHistory",

View File

@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional
from langchain_community.utilities.redis import get_client
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
@ -181,6 +180,14 @@ class RedisEntityStore(BaseEntityStore):
super().__init__(*args, **kwargs)
try:
from langchain_community.utilities.redis import get_client
except ImportError:
raise ImportError(
"Could not import langchain_community.utilities.redis.get_client. "
"Please install it with `pip install langchain-community`."
)
try:
self.redis_client = get_client(redis_url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:

View File

@ -1,133 +1,23 @@
from typing import Any, Dict, List, Type, Union
from typing import TYPE_CHECKING, Any
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain._api import create_importer
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
if TYPE_CHECKING:
from langchain_community.memory.kg import ConversationKGMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
class ConversationKGMemory(BaseChatMemory):
"""Knowledge graph conversation memory.
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
Integrates with external knowledge graph to store and retrieve
information about knowledge triples in the conversation.
"""
k: int = 2
human_prefix: str = "Human"
ai_prefix: str = "AI"
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
"""Number of previous utterances to include in the context."""
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)
summary_strings = []
for entity in entities:
knowledge = self.kg.get_entity_knowledge(entity)
if knowledge:
summary = f"On {entity}: {'. '.join(knowledge)}."
summary_strings.append(summary)
context: Union[str, List]
if not summary_strings:
context = [] if self.return_messages else ""
elif self.return_messages:
context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else:
context = "\n".join(summary_strings)
return {self.memory_key: context}
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
"""Get the output key for the prompt."""
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
return list(outputs.keys())[0]
return self.output_key
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
return get_entities(output)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
verbose=True,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()
__all__ = [
"ConversationKGMemory",
]

View File

@ -1,3 +1,23 @@
from langchain_community.memory.motorhead_memory import MotorheadMemory
from typing import TYPE_CHECKING, Any
__all__ = ["MotorheadMemory"]
from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.memory.motorhead_memory import MotorheadMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"MotorheadMemory": "langchain_community.memory.motorhead_memory"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"MotorheadMemory",
]

View File

@ -1,3 +1,23 @@
from langchain_community.memory.zep_memory import ZepMemory
from typing import TYPE_CHECKING, Any
__all__ = ["ZepMemory"]
from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.memory.zep_memory import ZepMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ZepMemory": "langchain_community.memory.zep_memory"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ZepMemory",
]