mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
expose memory key name (#808)
This commit is contained in:
parent
7728a848d0
commit
f46f1d28af
@ -230,12 +230,12 @@ class ConversationEntityMemory(Memory, BaseModel):
|
||||
llm: BaseLLM
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
||||
memory_keys: List[str] = ["entities", "history"] #: :meta private:
|
||||
output_key: Optional[str] = None
|
||||
input_key: Optional[str] = None
|
||||
store: Dict[str, Optional[str]] = {}
|
||||
entity_cache: List[str] = []
|
||||
k: int = 3
|
||||
chat_history_key: str = "history"
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
@ -243,7 +243,7 @@ class ConversationEntityMemory(Memory, BaseModel):
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return ["entities", "history"]
|
||||
return ["entities", self.chat_history_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
@ -265,7 +265,7 @@ class ConversationEntityMemory(Memory, BaseModel):
|
||||
entity_summaries[entity] = self.store.get(entity, "")
|
||||
self.entity_cache = entities
|
||||
return {
|
||||
"history": "\n".join(self.buffer[-self.k :]),
|
||||
self.chat_history_key: "\n".join(self.buffer[-self.k :]),
|
||||
"entities": entity_summaries,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user