langchain[patch]: Make more memory code handle community dependency as optional (#21199)

This commit is contained in:
Eugene Yurtsev 2024-05-02 11:05:26 -04:00 committed by GitHub
parent bd5d2c2674
commit df49404794
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 131 additions and 154 deletions

View File

@ -26,26 +26,9 @@
AIMessage, BaseMessage, HumanMessage AIMessage, BaseMessage, HumanMessage
""" # noqa: E501 """ # noqa: E501
from langchain_community.chat_message_histories import ( from typing import TYPE_CHECKING, Any
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
DynamoDBChatMessageHistory,
ElasticsearchChatMessageHistory,
FileChatMessageHistory,
MomentoChatMessageHistory,
MongoDBChatMessageHistory,
PostgresChatMessageHistory,
RedisChatMessageHistory,
SingleStoreDBChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
UpstashRedisChatMessageHistory,
XataChatMessageHistory,
ZepChatMessageHistory,
)
from langchain._api import create_importer
from langchain.memory.buffer import ( from langchain.memory.buffer import (
ConversationBufferMemory, ConversationBufferMemory,
ConversationStringBufferMemory, ConversationStringBufferMemory,
@ -59,15 +42,72 @@ from langchain.memory.entity import (
SQLiteEntityStore, SQLiteEntityStore,
UpstashRedisEntityStore, UpstashRedisEntityStore,
) )
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.motorhead_memory import MotorheadMemory
from langchain.memory.readonly import ReadOnlySharedMemory from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory from langchain.memory.simple import SimpleMemory
from langchain.memory.summary import ConversationSummaryMemory from langchain.memory.summary import ConversationSummaryMemory
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain.memory.token_buffer import ConversationTokenBufferMemory from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.memory.vectorstore import VectorStoreRetrieverMemory from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.zep_memory import ZepMemory
if TYPE_CHECKING:
from langchain_community.chat_message_histories import (
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
DynamoDBChatMessageHistory,
ElasticsearchChatMessageHistory,
FileChatMessageHistory,
MomentoChatMessageHistory,
MongoDBChatMessageHistory,
PostgresChatMessageHistory,
RedisChatMessageHistory,
SingleStoreDBChatMessageHistory,
SQLChatMessageHistory,
StreamlitChatMessageHistory,
UpstashRedisChatMessageHistory,
XataChatMessageHistory,
ZepChatMessageHistory,
)
from langchain_community.memory.kg import ConversationKGMemory
from langchain_community.memory.motorhead_memory import MotorheadMemory
from langchain_community.memory.zep_memory import ZepMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"MotorheadMemory": "langchain_community.memory.motorhead_memory",
"ConversationKGMemory": "langchain_community.memory.kg",
"ZepMemory": "langchain_community.memory.zep_memory",
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
"ChatMessageHistory": "langchain_community.chat_message_histories",
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
"FileChatMessageHistory": "langchain_community.chat_message_histories",
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
"XataChatMessageHistory": "langchain_community.chat_message_histories",
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [ __all__ = [
"AstraDBChatMessageHistory", "AstraDBChatMessageHistory",

View File

@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from itertools import islice from itertools import islice
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional
from langchain_community.utilities.redis import get_client
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
@ -181,6 +180,14 @@ class RedisEntityStore(BaseEntityStore):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
try:
from langchain_community.utilities.redis import get_client
except ImportError:
raise ImportError(
"Could not import langchain_community.utilities.redis.get_client. "
"Please install it with `pip install langchain-community`."
)
try: try:
self.redis_client = get_client(redis_url=url, decode_responses=True) self.redis_client = get_client(redis_url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error: except redis.exceptions.ConnectionError as error:

View File

@ -1,133 +1,23 @@
from typing import Any, Dict, List, Type, Union from typing import TYPE_CHECKING, Any
from langchain_community.graphs import NetworkxEntityGraph from langchain._api import create_importer
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.llm import LLMChain if TYPE_CHECKING:
from langchain.memory.chat_memory import BaseChatMemory from langchain_community.memory.kg import ConversationKGMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT, # Create a way to dynamically look up deprecated imports.
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, # Used to consolidate logic for raising deprecation warnings and
) # handling optional imports.
from langchain.memory.utils import get_prompt_input_key DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
class ConversationKGMemory(BaseChatMemory): def __getattr__(name: str) -> Any:
"""Knowledge graph conversation memory. """Look up attributes dynamically."""
return _import_attribute(name)
Integrates with external knowledge graph to store and retrieve
information about knowledge triples in the conversation.
"""
k: int = 2 __all__ = [
human_prefix: str = "Human" "ConversationKGMemory",
ai_prefix: str = "AI" ]
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
"""Number of previous utterances to include in the context."""
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)
summary_strings = []
for entity in entities:
knowledge = self.kg.get_entity_knowledge(entity)
if knowledge:
summary = f"On {entity}: {'. '.join(knowledge)}."
summary_strings.append(summary)
context: Union[str, List]
if not summary_strings:
context = [] if self.return_messages else ""
elif self.return_messages:
context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else:
context = "\n".join(summary_strings)
return {self.memory_key: context}
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
"""Get the output key for the prompt."""
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
return list(outputs.keys())[0]
return self.output_key
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
return get_entities(output)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
verbose=True,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()

View File

@ -1,3 +1,23 @@
from langchain_community.memory.motorhead_memory import MotorheadMemory from typing import TYPE_CHECKING, Any
__all__ = ["MotorheadMemory"] from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.memory.motorhead_memory import MotorheadMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"MotorheadMemory": "langchain_community.memory.motorhead_memory"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"MotorheadMemory",
]

View File

@ -1,3 +1,23 @@
from langchain_community.memory.zep_memory import ZepMemory from typing import TYPE_CHECKING, Any
__all__ = ["ZepMemory"] from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.memory.zep_memory import ZepMemory
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ZepMemory": "langchain_community.memory.zep_memory"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ZepMemory",
]