mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 10:54:15 +00:00
langchain[patch]: Make more memory code handle community dependency as optional (#21199)
This commit is contained in:
parent
bd5d2c2674
commit
df49404794
@ -26,26 +26,9 @@
|
||||
|
||||
AIMessage, BaseMessage, HumanMessage
|
||||
""" # noqa: E501
|
||||
from langchain_community.chat_message_histories import (
|
||||
AstraDBChatMessageHistory,
|
||||
CassandraChatMessageHistory,
|
||||
ChatMessageHistory,
|
||||
CosmosDBChatMessageHistory,
|
||||
DynamoDBChatMessageHistory,
|
||||
ElasticsearchChatMessageHistory,
|
||||
FileChatMessageHistory,
|
||||
MomentoChatMessageHistory,
|
||||
MongoDBChatMessageHistory,
|
||||
PostgresChatMessageHistory,
|
||||
RedisChatMessageHistory,
|
||||
SingleStoreDBChatMessageHistory,
|
||||
SQLChatMessageHistory,
|
||||
StreamlitChatMessageHistory,
|
||||
UpstashRedisChatMessageHistory,
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain._api import create_importer
|
||||
from langchain.memory.buffer import (
|
||||
ConversationBufferMemory,
|
||||
ConversationStringBufferMemory,
|
||||
@ -59,15 +42,72 @@ from langchain.memory.entity import (
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain.memory.kg import ConversationKGMemory
|
||||
from langchain.memory.motorhead_memory import MotorheadMemory
|
||||
from langchain.memory.readonly import ReadOnlySharedMemory
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
from langchain.memory.summary import ConversationSummaryMemory
|
||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
from langchain.memory.zep_memory import ZepMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
AstraDBChatMessageHistory,
|
||||
CassandraChatMessageHistory,
|
||||
ChatMessageHistory,
|
||||
CosmosDBChatMessageHistory,
|
||||
DynamoDBChatMessageHistory,
|
||||
ElasticsearchChatMessageHistory,
|
||||
FileChatMessageHistory,
|
||||
MomentoChatMessageHistory,
|
||||
MongoDBChatMessageHistory,
|
||||
PostgresChatMessageHistory,
|
||||
RedisChatMessageHistory,
|
||||
SingleStoreDBChatMessageHistory,
|
||||
SQLChatMessageHistory,
|
||||
StreamlitChatMessageHistory,
|
||||
UpstashRedisChatMessageHistory,
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
from langchain_community.memory.kg import ConversationKGMemory
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"MotorheadMemory": "langchain_community.memory.motorhead_memory",
|
||||
"ConversationKGMemory": "langchain_community.memory.kg",
|
||||
"ZepMemory": "langchain_community.memory.zep_memory",
|
||||
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"FileChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"XataChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
|
@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
|
||||
from itertools import islice
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
@ -181,6 +180,14 @@ class RedisEntityStore(BaseEntityStore):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
from langchain_community.utilities.redis import get_client
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import langchain_community.utilities.redis.get_client. "
|
||||
"Please install it with `pip install langchain-community`."
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
|
@ -1,133 +1,23 @@
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_community.graphs import NetworkxEntityGraph
|
||||
from langchain_community.graphs.networkx_graph import (
|
||||
KnowledgeTriple,
|
||||
get_entities,
|
||||
parse_triples,
|
||||
)
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain._api import create_importer
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.kg import ConversationKGMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
class ConversationKGMemory(BaseChatMemory):
|
||||
"""Knowledge graph conversation memory.
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
Integrates with external knowledge graph to store and retrieve
|
||||
information about knowledge triples in the conversation.
|
||||
"""
|
||||
|
||||
k: int = 2
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
|
||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
llm: BaseLanguageModel
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
"""Number of previous utterances to include in the context."""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
entities = self._get_current_entities(inputs)
|
||||
|
||||
summary_strings = []
|
||||
for entity in entities:
|
||||
knowledge = self.kg.get_entity_knowledge(entity)
|
||||
if knowledge:
|
||||
summary = f"On {entity}: {'. '.join(knowledge)}."
|
||||
summary_strings.append(summary)
|
||||
context: Union[str, List]
|
||||
if not summary_strings:
|
||||
context = [] if self.return_messages else ""
|
||||
elif self.return_messages:
|
||||
context = [
|
||||
self.summary_message_cls(content=text) for text in summary_strings
|
||||
]
|
||||
else:
|
||||
context = "\n".join(summary_strings)
|
||||
|
||||
return {self.memory_key: context}
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
|
||||
"""Get the output key for the prompt."""
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
return list(outputs.keys())[0]
|
||||
return self.output_key
|
||||
|
||||
def get_current_entities(self, input_string: str) -> List[str]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
)
|
||||
return get_entities(output)
|
||||
|
||||
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
||||
"""Get the current entities in the conversation."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
return self.get_current_entities(inputs[prompt_input_key])
|
||||
|
||||
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
verbose=True,
|
||||
)
|
||||
knowledge = parse_triples(output)
|
||||
return knowledge
|
||||
|
||||
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Get and update knowledge graph from the conversation history."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
|
||||
for triple in knowledge:
|
||||
self.kg.add_triple(triple)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self._get_and_update_kg(inputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.kg.clear()
|
||||
__all__ = [
|
||||
"ConversationKGMemory",
|
||||
]
|
||||
|
@ -1,3 +1,23 @@
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
__all__ = ["MotorheadMemory"]
|
||||
from langchain._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"MotorheadMemory": "langchain_community.memory.motorhead_memory"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MotorheadMemory",
|
||||
]
|
||||
|
@ -1,3 +1,23 @@
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
__all__ = ["ZepMemory"]
|
||||
from langchain._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ZepMemory": "langchain_community.memory.zep_memory"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ZepMemory",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user