mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-16 18:24:31 +00:00
Compare commits
13 Commits
langchain-
...
eugene/mov
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9edffaaed2 | ||
|
|
d09f8eebff | ||
|
|
1b19f839f9 | ||
|
|
e8d99c9620 | ||
|
|
c52a84c5a3 | ||
|
|
1ac61323d3 | ||
|
|
f82b2f4a6f | ||
|
|
3755822a2d | ||
|
|
017ae731d4 | ||
|
|
9ac0b0026b | ||
|
|
59fbe77510 | ||
|
|
8aea083bf3 | ||
|
|
aaf376a681 |
17
libs/community/langchain_community/memory/__init__.py
Normal file
17
libs/community/langchain_community/memory/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from langchain_community.memory.entity import (
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain_community.memory.kg import ConversationKGMemory
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
|
||||
__all__ = [
|
||||
"ConversationKGMemory",
|
||||
"RedisEntityStore",
|
||||
"UpstashRedisEntityStore",
|
||||
"SQLiteEntityStore",
|
||||
"MotorheadMemory",
|
||||
"ZepMemory",
|
||||
]
|
||||
268
libs/community/langchain_community/memory/entity.py
Normal file
268
libs/community/langchain_community/memory/entity.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import logging
|
||||
from itertools import islice
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from langchain_core.legacy.memory.entity import BaseEntityStore
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
|
||||
"""Upstash Redis backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import upstash_redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception:
|
||||
logger.error("Upstash Redis instance could not be initiated.")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
def scan_and_delete(cursor: int) -> int:
|
||||
cursor, keys_to_delete = self.redis_client.scan(
|
||||
cursor, f"{self.full_key_prefix}:*"
|
||||
)
|
||||
self.redis_client.delete(*keys_to_delete)
|
||||
return cursor
|
||||
|
||||
cursor = scan_and_delete(0)
|
||||
while cursor != 0:
|
||||
scan_and_delete(cursor)
|
||||
|
||||
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
redis_client: Any
|
||||
session_id: str = "default"
|
||||
key_prefix: str = "memory_store"
|
||||
ttl: Optional[int] = 60 * 60 * 24
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
# iterate a list in batches of size batch_size
|
||||
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
||||
iterator = iter(iterable)
|
||||
while batch := list(islice(iterator, batch_size)):
|
||||
yield batch
|
||||
|
||||
for keybatch in batched(
|
||||
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
|
||||
):
|
||||
self.redis_client.delete(*keybatch)
|
||||
|
||||
|
||||
class SQLiteEntityStore(BaseEntityStore):
|
||||
"""SQLite-backed Entity store"""
|
||||
|
||||
session_id: str = "default"
|
||||
table_name: str = "memory_store"
|
||||
conn: Any = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
db_file: str = "entities.db",
|
||||
table_name: str = "memory_store",
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sqlite3 python package. "
|
||||
"Please install it with `pip install sqlite3`."
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conn = sqlite3.connect(db_file)
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
@property
|
||||
def full_table_name(self) -> str:
|
||||
return f"{self.table_name}_{self.session_id}"
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(create_table_query)
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
query = f"""
|
||||
SELECT value
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
if result is not None:
|
||||
value = result[0]
|
||||
return value
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
query = f"""
|
||||
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
|
||||
VALUES (?, ?)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key, value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key,))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
query = f"""
|
||||
SELECT 1
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
def clear(self) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query)
|
||||
133
libs/community/langchain_community/memory/kg.py
Normal file
133
libs/community/langchain_community/memory/kg.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.legacy.chains.llm import LLMChain
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
||||
)
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_community.graphs import NetworkxEntityGraph
|
||||
from langchain_community.graphs.networkx_graph import (
|
||||
KnowledgeTriple,
|
||||
get_entities,
|
||||
parse_triples,
|
||||
)
|
||||
|
||||
|
||||
class ConversationKGMemory(BaseChatMemory):
|
||||
"""Knowledge graph conversation memory.
|
||||
|
||||
Integrates with external knowledge graph to store and retrieve
|
||||
information about knowledge triples in the conversation.
|
||||
"""
|
||||
|
||||
k: int = 2
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
|
||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
llm: BaseLanguageModel
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
"""Number of previous utterances to include in the context."""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
entities = self._get_current_entities(inputs)
|
||||
|
||||
summary_strings = []
|
||||
for entity in entities:
|
||||
knowledge = self.kg.get_entity_knowledge(entity)
|
||||
if knowledge:
|
||||
summary = f"On {entity}: {'. '.join(knowledge)}."
|
||||
summary_strings.append(summary)
|
||||
context: Union[str, List]
|
||||
if not summary_strings:
|
||||
context = [] if self.return_messages else ""
|
||||
elif self.return_messages:
|
||||
context = [
|
||||
self.summary_message_cls(content=text) for text in summary_strings
|
||||
]
|
||||
else:
|
||||
context = "\n".join(summary_strings)
|
||||
|
||||
return {self.memory_key: context}
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
|
||||
"""Get the output key for the prompt."""
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
return list(outputs.keys())[0]
|
||||
return self.output_key
|
||||
|
||||
def get_current_entities(self, input_string: str) -> List[str]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
)
|
||||
return get_entities(output)
|
||||
|
||||
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
||||
"""Get the current entities in the conversation."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
return self.get_current_entities(inputs[prompt_input_key])
|
||||
|
||||
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
verbose=True,
|
||||
)
|
||||
knowledge = parse_triples(output)
|
||||
return knowledge
|
||||
|
||||
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Get and update knowledge graph from the conversation history."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
|
||||
for triple in knowledge:
|
||||
self.kg.add_triple(triple)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self._get_and_update_kg(inputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.kg.clear()
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.messages import get_buffer_string
|
||||
|
||||
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
|
||||
|
||||
|
||||
class MotorheadMemory(BaseChatMemory):
|
||||
"""Chat message memory backed by Motorhead service."""
|
||||
|
||||
url: str = MANAGED_URL
|
||||
timeout: int = 3000
|
||||
memory_key: str = "history"
|
||||
session_id: str
|
||||
context: Optional[str] = None
|
||||
|
||||
# Managed Params
|
||||
api_key: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
|
||||
def __get_headers(self) -> Dict[str, str]:
|
||||
is_managed = self.url == MANAGED_URL
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if is_managed and not (self.api_key and self.client_id):
|
||||
raise ValueError(
|
||||
"""
|
||||
You must provide an API key or a client ID to use the managed
|
||||
version of Motorhead. Visit https://getmetal.io for more information.
|
||||
"""
|
||||
)
|
||||
|
||||
if is_managed and self.api_key and self.client_id:
|
||||
headers["x-metal-api-key"] = self.api_key
|
||||
headers["x-metal-client-id"] = self.client_id
|
||||
|
||||
return headers
|
||||
|
||||
async def init(self) -> None:
|
||||
res = requests.get(
|
||||
f"{self.url}/sessions/{self.session_id}/memory",
|
||||
timeout=self.timeout,
|
||||
headers=self.__get_headers(),
|
||||
)
|
||||
res_data = res.json()
|
||||
res_data = res_data.get("data", res_data) # Handle Managed Version
|
||||
|
||||
messages = res_data.get("messages", [])
|
||||
context = res_data.get("context", "NONE")
|
||||
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "AI":
|
||||
self.chat_memory.add_ai_message(message["content"])
|
||||
else:
|
||||
self.chat_memory.add_user_message(message["content"])
|
||||
|
||||
if context and context != "NONE":
|
||||
self.context = context
|
||||
|
||||
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.return_messages:
|
||||
return {self.memory_key: self.chat_memory.messages}
|
||||
else:
|
||||
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return [self.memory_key]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
requests.post(
|
||||
f"{self.url}/sessions/{self.session_id}/memory",
|
||||
timeout=self.timeout,
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "Human", "content": f"{input_str}"},
|
||||
{"role": "AI", "content": f"{output_str}"},
|
||||
]
|
||||
},
|
||||
headers=self.__get_headers(),
|
||||
)
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
def delete_session(self) -> None:
|
||||
"""Delete a session"""
|
||||
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
|
||||
125
libs/community/langchain_community/memory/zep_memory.py
Normal file
125
libs/community/langchain_community/memory/zep_memory.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.legacy.memory import ConversationBufferMemory
|
||||
|
||||
from langchain_community.chat_message_histories import ZepChatMessageHistory
|
||||
|
||||
|
||||
class ZepMemory(ConversationBufferMemory):
|
||||
"""Persist your chain history to the Zep MemoryStore.
|
||||
|
||||
The number of messages returned by Zep and when the Zep server summarizes chat
|
||||
histories is configurable. See the Zep documentation for more details.
|
||||
|
||||
Documentation: https://docs.getzep.com
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
memory = ZepMemory(
|
||||
session_id=session_id, # Identifies your user or a user's session
|
||||
url=ZEP_API_URL, # Your Zep server's URL
|
||||
api_key=<your_api_key>, # Optional
|
||||
memory_key="history", # Ensure this matches the key used in
|
||||
# chain's prompt template
|
||||
return_messages=True, # Does your prompt template expect a string
|
||||
# or a list of Messages?
|
||||
)
|
||||
chain = LLMChain(memory=memory,...) # Configure your chain to use the ZepMemory
|
||||
instance
|
||||
|
||||
|
||||
Note:
|
||||
To persist metadata alongside your chat history, your will need to create a
|
||||
custom Chain class that overrides the `prep_outputs` method to include the metadata
|
||||
in the call to `self.memory.save_context`.
|
||||
|
||||
|
||||
Zep - Fast, scalable building blocks for LLM Apps
|
||||
=========
|
||||
Zep is an open source platform for productionizing LLM apps. Go from a prototype
|
||||
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
|
||||
rewriting code.
|
||||
|
||||
For server installation instructions and more, see:
|
||||
https://docs.getzep.com/deployment/quickstart/
|
||||
|
||||
For more information on the zep-python package, see:
|
||||
https://github.com/getzep/zep-python
|
||||
|
||||
"""
|
||||
|
||||
chat_memory: ZepChatMessageHistory
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
url: str = "http://localhost:8000",
|
||||
api_key: Optional[str] = None,
|
||||
output_key: Optional[str] = None,
|
||||
input_key: Optional[str] = None,
|
||||
return_messages: bool = False,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "AI",
|
||||
memory_key: str = "history",
|
||||
):
|
||||
"""Initialize ZepMemory.
|
||||
|
||||
Args:
|
||||
session_id (str): Identifies your user or a user's session
|
||||
url (str, optional): Your Zep server's URL. Defaults to
|
||||
"http://localhost:8000".
|
||||
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
|
||||
output_key (Optional[str], optional): The key to use for the output message.
|
||||
Defaults to None.
|
||||
input_key (Optional[str], optional): The key to use for the input message.
|
||||
Defaults to None.
|
||||
return_messages (bool, optional): Does your prompt template expect a string
|
||||
or a list of Messages? Defaults to False
|
||||
i.e. return a string.
|
||||
human_prefix (str, optional): The prefix to use for human messages.
|
||||
Defaults to "Human".
|
||||
ai_prefix (str, optional): The prefix to use for AI messages.
|
||||
Defaults to "AI".
|
||||
memory_key (str, optional): The key to use for the memory.
|
||||
Defaults to "history".
|
||||
Ensure that this matches the key used in
|
||||
chain's prompt template.
|
||||
"""
|
||||
chat_message_history = ZepChatMessageHistory(
|
||||
session_id=session_id,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
)
|
||||
super().__init__(
|
||||
chat_memory=chat_message_history,
|
||||
output_key=output_key,
|
||||
input_key=input_key,
|
||||
return_messages=return_messages,
|
||||
human_prefix=human_prefix,
|
||||
ai_prefix=ai_prefix,
|
||||
memory_key=memory_key,
|
||||
)
|
||||
|
||||
def save_context(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
outputs: Dict[str, str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): The inputs to the chain.
|
||||
outputs (Dict[str, str]): The outputs from the chain.
|
||||
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
|
||||
the context. Defaults to None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_user_message(input_str, metadata=metadata)
|
||||
self.chat_memory.add_ai_message(output_str, metadata=metadata)
|
||||
0
libs/core/langchain_core/legacy/__init__.py
Normal file
0
libs/core/langchain_core/legacy/__init__.py
Normal file
7
libs/core/langchain_core/legacy/chains/__init__.py
Normal file
7
libs/core/langchain_core/legacy/chains/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from langchain_core.legacy.chains.base import Chain
|
||||
from langchain_core.legacy.chains.llm import LLMChain
|
||||
|
||||
__all__ = [
|
||||
"Chain",
|
||||
"LLMChain",
|
||||
]
|
||||
734
libs/core/langchain_core/legacy/chains/base.py
Normal file
734
libs/core/langchain_core/legacy/chains/base.py
Normal file
@@ -0,0 +1,734 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.outputs import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.utils import create_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
from langchain_core.globals import get_verbose
|
||||
|
||||
return get_verbose()
|
||||
|
||||
|
||||
RUN_KEY = "__run"
|
||||
|
||||
|
||||
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
"""Abstract base class for creating structured sequences of calls to components.
|
||||
|
||||
Chains should be used to encode a sequence of calls to components like
|
||||
models, document retrievers, other chains, etc., and provide a simple interface
|
||||
to this sequence.
|
||||
|
||||
The Chain interface makes it easy to create apps that are:
|
||||
- Stateful: add Memory to any Chain to give it state,
|
||||
- Observable: pass Callbacks to a Chain to execute additional functionality,
|
||||
like logging, outside the main sequence of component calls,
|
||||
- Composable: the Chain API is flexible enough that it is easy to combine
|
||||
Chains with other components, including other Chains.
|
||||
|
||||
The main methods exposed by chains are:
|
||||
- `__call__`: Chains are callable. The `__call__` method is the primary way to
|
||||
execute a Chain. This takes inputs as a dictionary and returns a
|
||||
dictionary output.
|
||||
- `run`: A convenience method that takes inputs as args/kwargs and returns the
|
||||
output as a string or object. This method can only be used for a subset of
|
||||
chains and cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
Memory is a class that gets called at the start
|
||||
and at the end of every chain. At the start, memory loads variables and passes
|
||||
them along in the chain. At the end, it saves any returned variables.
|
||||
There are many different types of memory - please see memory docs
|
||||
for the full catalog."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Optional list of callback handlers (or callback manager). Defaults to None.
|
||||
Callback handlers are called throughout the lifecycle of a call to a chain,
|
||||
starting with on_chain_start, ending with on_chain_end or on_chain_error.
|
||||
Each custom chain can optionally call additional callback methods, see Callback docs
|
||||
for full details."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||
will be printed to the console. Defaults to the global `verbose` value,
|
||||
accessible via `langchain.globals.get_verbose()`."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the chain. Defaults to None.
|
||||
These tags will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None.
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED] Use `callbacks` instead."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainInput", **{k: (Any, None) for k in self.input_keys}
|
||||
)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = self.prep_inputs(input)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = await self.aprep_inputs(input)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
if values.get("callbacks") is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both callback_manager and callbacks. "
|
||||
"callback_manager is deprecated, callbacks is the preferred "
|
||||
"parameter to pass in."
|
||||
)
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""Set the chain verbosity.
|
||||
|
||||
Defaults to the global setting if not specified by the user.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Keys expected to be in the chain input."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Keys expected to be in the chain output."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.__call__`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.acall`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
return await run_in_executor(
|
||||
None, self._call, inputs, run_manager.get_sync() if run_manager else None
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
|
||||
return self.invoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def acall(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
return await self.ainvoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prepare chain outputs, and save info about this run to memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of chain inputs, including any inputs added by chain
|
||||
memory.
|
||||
outputs: Dictionary of initial chain outputs.
|
||||
return_only_outputs: Whether to only return the chain outputs. If False,
|
||||
inputs are also added to the final outputs.
|
||||
|
||||
Returns:
|
||||
A dict of the final chain outputs.
|
||||
"""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = await self.memory.aload_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
return self.output_keys[0]
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Convenience method for executing chain.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this
|
||||
method expects inputs to be passed directly in as positional arguments or
|
||||
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
|
||||
with all the inputs
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
chain.run("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
chain.run(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
# Run at start to make sure this is possible/defined
|
||||
_output_key = self._run_output_key
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Convenience method for executing chain.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this
|
||||
method expects inputs to be passed directly in as positional arguments or
|
||||
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
|
||||
with all the inputs
|
||||
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
await chain.arun("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
await chain.arun(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
elif args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (
|
||||
await self.acall(
|
||||
args[0], callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (
|
||||
await self.acall(
|
||||
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Dictionary representation of chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
|
||||
method.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the chain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.dict(exclude_unset=True)
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
_dict = super().dict(**kwargs)
|
||||
try:
|
||||
_dict["_type"] = self._chain_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
if "_type" not in chain_dict:
|
||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix.endswith((".yaml", ".yml")):
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
422
libs/core/langchain_core/legacy/chains/llm.py
Normal file
422
libs/core/langchain_core/legacy/chains/llm.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.language_models import (
|
||||
BaseLanguageModel,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.legacy.chains.base import Chain
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableBranch,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain_core.runnables.configurable import DynamicRunnable
|
||||
from langchain_core.utils.input import get_colored_text
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.1.17",
|
||||
alternative="RunnableSequence, e.g., `prompt | llm`",
|
||||
removal="0.3.0",
|
||||
)
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
This class is deprecated. See below for an example implementation using
|
||||
LangChain runnables:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import OpenAI
|
||||
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = OpenAI()
|
||||
chain = prompt | llm
|
||||
|
||||
chain.invoke("your adjective here")
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
"""Prompt object to use."""
|
||||
llm: Union[
|
||||
Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
|
||||
]
|
||||
"""Language model to call."""
|
||||
output_key: str = "text" #: :meta private:
|
||||
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
|
||||
"""Output parser to use.
|
||||
Defaults to one that takes the most likely string but does not change it
|
||||
otherwise."""
|
||||
return_final_only: bool = True
|
||||
"""Whether to return only the final parsed result. Defaults to True.
|
||||
If false, will return a bunch of extra information about the generation."""
|
||||
llm_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if self.return_final_only:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "full_generation"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
callbacks = run_manager.get_child() if run_manager else None
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=callbacks,
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
else:
|
||||
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
||||
cast(List, prompts), {"callbacks": callbacks}
|
||||
)
|
||||
generations: List[List[Generation]] = []
|
||||
for res in results:
|
||||
if isinstance(res, BaseMessage):
|
||||
generations.append([ChatGeneration(message=res)])
|
||||
else:
|
||||
generations.append([Generation(text=res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
|
||||
callbacks = run_manager.get_child() if run_manager else None
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return await self.llm.agenerate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=callbacks,
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
else:
|
||||
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
||||
cast(List, prompts), {"callbacks": callbacks}
|
||||
)
|
||||
generations: List[List[Generation]] = []
|
||||
for res in results:
|
||||
if isinstance(res, BaseMessage):
|
||||
generations.append([ChatGeneration(message=res)])
|
||||
else:
|
||||
generations.append([Generation(text=res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if len(input_list) == 0:
|
||||
return [], stop
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
async def aprep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if len(input_list) == 0:
|
||||
return [], stop
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
async def aapply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
await run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
return self.output_key
|
||||
|
||||
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
|
||||
"""Create outputs from response."""
|
||||
result = [
|
||||
# Get the text of the top generated string.
|
||||
{
|
||||
self.output_key: self.output_parser.parse_result(generation),
|
||||
"full_generation": generation,
|
||||
}
|
||||
for generation in llm_result.generations
|
||||
]
|
||||
if self.return_final_only:
|
||||
result = [{self.output_key: r[self.output_key]} for r in result]
|
||||
return result
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = await self.agenerate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
|
||||
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
"""Call predict and then parse the results."""
|
||||
warnings.warn(
|
||||
"The predict_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = self.predict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
async def apredict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call apredict and then parse the results."""
|
||||
warnings.warn(
|
||||
"The apredict_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = await self.apredict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
warnings.warn(
|
||||
"The apply_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = self.apply(input_list, callbacks=callbacks)
|
||||
return self._parse_generation(result)
|
||||
|
||||
def _parse_generation(
|
||||
self, generation: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
return [
|
||||
self.prompt.output_parser.parse(res[self.output_key])
|
||||
for res in generation
|
||||
]
|
||||
else:
|
||||
return generation
|
||||
|
||||
async def aapply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
warnings.warn(
|
||||
"The aapply_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = await self.aapply(input_list, callbacks=callbacks)
|
||||
return self._parse_generation(result)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_chain"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
|
||||
"""Create LLMChain from LLM and template."""
|
||||
prompt_template = PromptTemplate.from_template(template)
|
||||
return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
def _get_num_tokens(self, text: str) -> int:
|
||||
return _get_language_model(self.llm).get_num_tokens(text)
|
||||
|
||||
|
||||
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
|
||||
if isinstance(llm_like, BaseLanguageModel):
|
||||
return llm_like
|
||||
elif isinstance(llm_like, RunnableBinding):
|
||||
return _get_language_model(llm_like.bound)
|
||||
elif isinstance(llm_like, RunnableWithFallbacks):
|
||||
return _get_language_model(llm_like.runnable)
|
||||
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
|
||||
return _get_language_model(llm_like.default)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to extract BaseLanguageModel from llm_like object of type "
|
||||
f"{type(llm_like)}"
|
||||
)
|
||||
35
libs/core/langchain_core/legacy/memory/__init__.py
Normal file
35
libs/core/langchain_core/legacy/memory/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from langchain_core.legacy.memory.base import BaseMemory
|
||||
from langchain_core.legacy.memory.buffer import (
|
||||
ConversationBufferMemory,
|
||||
ConversationStringBufferMemory,
|
||||
)
|
||||
from langchain_core.legacy.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.combined import CombinedMemory
|
||||
from langchain_core.legacy.memory.entity import (
|
||||
ConversationEntityMemory,
|
||||
InMemoryEntityStore,
|
||||
)
|
||||
from langchain_core.legacy.memory.readonly import ReadOnlySharedMemory
|
||||
from langchain_core.legacy.memory.simple import SimpleMemory
|
||||
from langchain_core.legacy.memory.summary import ConversationSummaryMemory
|
||||
from langchain_core.legacy.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain_core.legacy.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain_core.legacy.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"BaseChatMemory",
|
||||
"CombinedMemory",
|
||||
"ConversationBufferMemory",
|
||||
"ConversationBufferWindowMemory",
|
||||
"ConversationEntityMemory",
|
||||
"ConversationStringBufferMemory",
|
||||
"ConversationSummaryBufferMemory",
|
||||
"ConversationSummaryMemory",
|
||||
"ConversationTokenBufferMemory",
|
||||
"InMemoryEntityStore",
|
||||
"ReadOnlySharedMemory",
|
||||
"SimpleMemory",
|
||||
"VectorStoreRetrieverMemory",
|
||||
]
|
||||
83
libs/core/langchain_core/legacy/memory/base.py
Normal file
83
libs/core/langchain_core/legacy/memory/base.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""**Memory** maintains Chain state, incorporating context from past runs.
|
||||
|
||||
**Class hierarchy for Memory:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseMemory --> <name>Memory --> <name>Memory # Examples: BaseChatMemory -> MotorheadMemory
|
||||
|
||||
""" # noqa: E501
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
class BaseMemory(Serializable, ABC):
|
||||
"""Abstract base class for memory in Chains.
|
||||
|
||||
Memory refers to state in Chains. Memory can be used to store information about
|
||||
past executions of a Chain and inject that information into the inputs of
|
||||
future executions of the Chain. For example, for conversational Chains Memory
|
||||
can be used to store conversations and automatically add them to future model
|
||||
prompts so that the model has the necessary context to respond coherently to
|
||||
the latest input.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The string keys this memory class will add to chain inputs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
return await run_in_executor(None, self.load_memory_variables, inputs)
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
await run_in_executor(None, self.save_context, inputs, outputs)
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await run_in_executor(None, self.clear)
|
||||
136
libs/core/langchain_core/legacy/memory/buffer.py
Normal file
136
libs/core/langchain_core/legacy/memory/buffer.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.legacy.memory.base import BaseMemory
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
class ConversationBufferMemory(BaseChatMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
async def abuffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return (
|
||||
await self.abuffer_as_messages()
|
||||
if self.return_messages
|
||||
else await self.abuffer_as_str()
|
||||
)
|
||||
|
||||
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
return self._buffer_as_str(self.chat_memory.messages)
|
||||
|
||||
async def abuffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
messages = await self.chat_memory.aget_messages()
|
||||
return self._buffer_as_str(messages)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
async def abuffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return await self.chat_memory.aget_messages()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.abuffer()
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
|
||||
class ConversationStringBufferMemory(BaseMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
"""Prefix to use for AI generated responses."""
|
||||
buffer: str = ""
|
||||
output_key: Optional[str] = None
|
||||
input_key: Optional[str] = None
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@root_validator()
|
||||
def validate_chains(cls, values: Dict) -> Dict:
|
||||
"""Validate that return messages is not True."""
|
||||
if values.get("return_messages", False):
|
||||
raise ValueError(
|
||||
"return_messages must be False for ConversationStringBufferMemory"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return self.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
output_key = list(outputs.keys())[0]
|
||||
else:
|
||||
output_key = self.output_key
|
||||
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
|
||||
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
||||
self.buffer += "\n" + "\n".join([human, ai])
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
return self.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.buffer = ""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
self.clear()
|
||||
46
libs/core/langchain_core/legacy/memory/buffer_window.py
Normal file
46
libs/core/langchain_core/legacy/memory/buffer_window.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""Buffer for storing conversation memory inside a limited size window."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history" #: :meta private:
|
||||
k: int = 5
|
||||
"""Number of messages to store in buffer."""
|
||||
|
||||
@property
|
||||
def buffer(self) -> Union[str, List[BaseMessage]]:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
74
libs/core/langchain_core/legacy/memory/chat_memory.py
Normal file
74
libs/core/langchain_core/legacy/memory/chat_memory.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from langchain_core.chat_history import (
|
||||
BaseChatMessageHistory,
|
||||
InMemoryChatMessageHistory,
|
||||
)
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
|
||||
class BaseChatMemory(BaseMemory, ABC):
|
||||
"""Abstract base class for chat memory."""
|
||||
|
||||
chat_memory: BaseChatMessageHistory = Field(
|
||||
default_factory=InMemoryChatMessageHistory
|
||||
)
|
||||
output_key: Optional[str] = None
|
||||
input_key: Optional[str] = None
|
||||
return_messages: bool = False
|
||||
|
||||
def _get_input_output(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) == 1:
|
||||
output_key = list(outputs.keys())[0]
|
||||
elif "output" in outputs:
|
||||
output_key = "output"
|
||||
warnings.warn(
|
||||
f"'{self.__class__.__name__}' got multiple output keys:"
|
||||
f" {outputs.keys()}. The default 'output' key is being used."
|
||||
f" If this is not desired, please manually set 'output_key'."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got multiple output keys: {outputs.keys()}, cannot "
|
||||
f"determine which to store in memory. Please set the "
|
||||
f"'output_key' explicitly."
|
||||
)
|
||||
else:
|
||||
output_key = self.output_key
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
await self.chat_memory.aadd_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await self.chat_memory.aclear()
|
||||
81
libs/core/langchain_core/legacy/memory/combined.py
Normal file
81
libs/core/langchain_core/legacy/memory/combined.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.pydantic_v1 import validator
|
||||
|
||||
|
||||
class CombinedMemory(BaseMemory):
|
||||
"""Combining multiple memories' data together."""
|
||||
|
||||
memories: List[BaseMemory]
|
||||
"""For tracking all the memories that should be accessed."""
|
||||
|
||||
@validator("memories")
|
||||
def check_repeated_memory_variable(
|
||||
cls, value: List[BaseMemory]
|
||||
) -> List[BaseMemory]:
|
||||
all_variables: Set[str] = set()
|
||||
for val in value:
|
||||
overlap = all_variables.intersection(val.memory_variables)
|
||||
if overlap:
|
||||
raise ValueError(
|
||||
f"The same variables {overlap} are found in multiple"
|
||||
"memory object, which is not allowed by CombinedMemory."
|
||||
)
|
||||
all_variables |= set(val.memory_variables)
|
||||
|
||||
return value
|
||||
|
||||
@validator("memories")
|
||||
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
|
||||
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
||||
for val in value:
|
||||
if isinstance(val, BaseChatMemory):
|
||||
if val.input_key is None:
|
||||
warnings.warn(
|
||||
"When using CombinedMemory, "
|
||||
"input keys should be so the input is known. "
|
||||
f" Was not set on {val}"
|
||||
)
|
||||
return value
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""All the memory variables that this instance provides."""
|
||||
"""Collected from the all the linked memories."""
|
||||
|
||||
memory_variables = []
|
||||
|
||||
for memory in self.memories:
|
||||
memory_variables.extend(memory.memory_variables)
|
||||
|
||||
return memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load all vars from sub-memories."""
|
||||
memory_data: Dict[str, Any] = {}
|
||||
|
||||
# Collect vars from all sub-memories
|
||||
for memory in self.memories:
|
||||
data = memory.load_memory_variables(inputs)
|
||||
for key, value in data.items():
|
||||
if key in memory_data:
|
||||
raise ValueError(
|
||||
f"The variable {key} is repeated in the CombinedMemory."
|
||||
)
|
||||
memory_data[key] = value
|
||||
|
||||
return memory_data
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this session for every memory."""
|
||||
# Save context for all sub-memories
|
||||
for memory in self.memories:
|
||||
memory.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear context from this session for every memory."""
|
||||
for memory in self.memories:
|
||||
memory.clear()
|
||||
221
libs/core/langchain_core/legacy/memory/entity.py
Normal file
221
libs/core/langchain_core/legacy/memory/entity.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.legacy.chains import LLMChain
|
||||
from langchain_core.legacy.memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
ENTITY_SUMMARIZATION_PROMPT,
|
||||
)
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEntityStore(BaseModel, ABC):
|
||||
"""Abstract base class for Entity store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""Get entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
"""Set entity value in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if entity exists in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Delete all entities from store."""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryEntityStore(BaseEntityStore):
|
||||
"""In-memory Entity store."""
|
||||
|
||||
store: Dict[str, Optional[str]] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
del self.store[key]
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.store
|
||||
|
||||
def clear(self) -> None:
|
||||
return self.store.clear()
|
||||
|
||||
|
||||
class ConversationEntityMemory(BaseChatMemory):
|
||||
"""Entity extractor & summarizer memory.
|
||||
|
||||
Extracts named entities from the recent chat history and generates summaries.
|
||||
With a swappable entity store, persisting entities across conversations.
|
||||
Defaults to an in-memory entity store, and can be swapped out for a Redis,
|
||||
SQLite, or other entity store.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
||||
|
||||
# Cache of recently detected entity names, if any
|
||||
# It is updated when load_memory_variables is called:
|
||||
entity_cache: List[str] = []
|
||||
|
||||
# Number of recent message pairs to consider when updating entities:
|
||||
k: int = 3
|
||||
|
||||
chat_history_key: str = "history"
|
||||
|
||||
# Store to manage entity-related data:
|
||||
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""Access chat memory messages."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return ["entities", self.chat_history_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns chat history and all generated entities with summaries if available,
|
||||
and updates or clears the recent entity cache.
|
||||
|
||||
New entity name can be found when calling this method, before the entity
|
||||
summaries are generated, so the entity cache values may be empty if no entity
|
||||
descriptions are generated yet.
|
||||
"""
|
||||
|
||||
# Create an LLMChain for predicting entity names from the recent chat history:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
# Generates a comma-separated list of named entities,
|
||||
# e.g. "Jane, White House, UFO"
|
||||
# or "NONE" if no named entities are extracted:
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=inputs[prompt_input_key],
|
||||
)
|
||||
|
||||
# If no named entities are extracted, assigns an empty list.
|
||||
if output.strip() == "NONE":
|
||||
entities = []
|
||||
else:
|
||||
# Make a list of the extracted entities:
|
||||
entities = [w.strip() for w in output.split(",")]
|
||||
|
||||
# Make a dictionary of entities with summary if exists:
|
||||
entity_summaries = {}
|
||||
|
||||
for entity in entities:
|
||||
entity_summaries[entity] = self.entity_store.get(entity, "")
|
||||
|
||||
# Replaces the entity name cache with the most recently discussed entities,
|
||||
# or if no entities were extracted, clears the cache:
|
||||
self.entity_cache = entities
|
||||
|
||||
# Should we return as message objects or as a string?
|
||||
if self.return_messages:
|
||||
# Get last `k` pair of chat messages:
|
||||
buffer: Any = self.buffer[-self.k * 2 :]
|
||||
else:
|
||||
# Reuse the string we made earlier:
|
||||
buffer = buffer_string
|
||||
|
||||
return {
|
||||
self.chat_history_key: buffer,
|
||||
"entities": entity_summaries,
|
||||
}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""
|
||||
Save context from this conversation history to the entity store.
|
||||
|
||||
Generates a summary for each entity in the entity cache by prompting
|
||||
the model, and saves these summaries to the entity store.
|
||||
"""
|
||||
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
input_data = inputs[prompt_input_key]
|
||||
|
||||
# Create an LLMChain for predicting entity summarization from the context
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
|
||||
|
||||
# Generate new summaries for entities and save them in the entity store
|
||||
for entity in self.entity_cache:
|
||||
# Get existing summary if it exists
|
||||
existing_summary = self.entity_store.get(entity, "")
|
||||
output = chain.predict(
|
||||
summary=existing_summary,
|
||||
entity=entity,
|
||||
history=buffer_string,
|
||||
input=input_data,
|
||||
)
|
||||
# Save the updated summary to the entity store
|
||||
self.entity_store.set(entity, output.strip())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
self.entity_cache.clear()
|
||||
self.entity_store.clear()
|
||||
165
libs/core/langchain_core/legacy/memory/prompt.py
Normal file
165
libs/core/langchain_core/legacy/memory/prompt.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# flake8: noqa
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI.
|
||||
|
||||
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
||||
|
||||
You are constantly learning and improving, and your capabilities are constantly evolving. You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. You have access to some personalized information provided by the human in the Context section below. Additionally, you are able to generate your own text based on the input you receive, allowing you to engage in discussions and provide explanations and descriptions on a wide range of topics.
|
||||
|
||||
Overall, you are a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether the human needs help with a specific question or just wants to have a conversation about a particular topic, you are here to assist.
|
||||
|
||||
Context:
|
||||
{entities}
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Last line:
|
||||
Human: {input}
|
||||
You:"""
|
||||
|
||||
ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate(
|
||||
input_variables=["entities", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE,
|
||||
)
|
||||
|
||||
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
|
||||
|
||||
EXAMPLE
|
||||
Current summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
|
||||
|
||||
New lines of conversation:
|
||||
Human: Why do you think artificial intelligence is a force for good?
|
||||
AI: Because artificial intelligence will help humans reach their full potential.
|
||||
|
||||
New summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
|
||||
END OF EXAMPLE
|
||||
|
||||
Current summary:
|
||||
{summary}
|
||||
|
||||
New lines of conversation:
|
||||
{new_lines}
|
||||
|
||||
New summary:"""
|
||||
SUMMARY_PROMPT = PromptTemplate(
|
||||
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
|
||||
|
||||
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
|
||||
|
||||
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
|
||||
Output: Langchain
|
||||
END OF EXAMPLE
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
|
||||
Output: Langchain, Person #2
|
||||
END OF EXAMPLE
|
||||
|
||||
Conversation history (for reference only):
|
||||
{history}
|
||||
Last line of conversation (for extraction):
|
||||
Human: {input}
|
||||
|
||||
Output:"""
|
||||
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
|
||||
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
|
||||
|
||||
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
|
||||
|
||||
Full conversation history (for context):
|
||||
{history}
|
||||
|
||||
Entity to summarize:
|
||||
{entity}
|
||||
|
||||
Existing summary of {entity}:
|
||||
{summary}
|
||||
|
||||
Last line of conversation:
|
||||
Human: {input}
|
||||
Updated summary:"""
|
||||
|
||||
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
|
||||
input_variables=["entity", "summary", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
KG_TRIPLE_DELIMITER = "<|>"
|
||||
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
|
||||
"You are a networked intelligence helping a human track knowledge triples"
|
||||
" about all relevant people, things, concepts, etc. and integrating"
|
||||
" them with your knowledge stored within your weights"
|
||||
" as well as that stored in a knowledge graph."
|
||||
" Extract all of the knowledge triples from the last line of conversation."
|
||||
" A knowledge triple is a clause that contains a subject, a predicate,"
|
||||
" and an object. The subject is the entity being described,"
|
||||
" the predicate is the property of the subject that is being"
|
||||
" described, and the object is the value of the property.\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Did you hear aliens landed in Area 51?\n"
|
||||
"AI: No, I didn't hear that. What do you know about Area 51?\n"
|
||||
"Person #1: It's a secret military base in Nevada.\n"
|
||||
"AI: What do you know about Nevada?\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
|
||||
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
|
||||
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Hello.\n"
|
||||
"AI: Hi! How are you?\n"
|
||||
"Person #1: I'm good. How are you?\n"
|
||||
"AI: I'm good too.\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: I'm going to the store.\n\n"
|
||||
"Output: NONE\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: What do you know about Descartes?\n"
|
||||
"AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n"
|
||||
"Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n"
|
||||
"AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
|
||||
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"Conversation history (for reference only):\n"
|
||||
"{history}"
|
||||
"\nLast line of conversation (for extraction):\n"
|
||||
"Human: {input}\n\n"
|
||||
"Output:"
|
||||
)
|
||||
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
|
||||
)
|
||||
26
libs/core/langchain_core/legacy/memory/readonly.py
Normal file
26
libs/core/langchain_core/legacy/memory/readonly.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
|
||||
|
||||
class ReadOnlySharedMemory(BaseMemory):
|
||||
"""A memory wrapper that is read-only and cannot be changed."""
|
||||
|
||||
memory: BaseMemory
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Return memory variables."""
|
||||
return self.memory.memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load memory variables from memory."""
|
||||
return self.memory.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
26
libs/core/langchain_core/legacy/memory/simple.py
Normal file
26
libs/core/langchain_core/legacy/memory/simple.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
"""Simple memory for storing context or other information that shouldn't
|
||||
ever change between prompts.
|
||||
"""
|
||||
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed, my memory is set in stone."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
97
libs/core/langchain_core/legacy/memory/summary.py
Normal file
97
libs/core/langchain_core/legacy/memory/summary.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.legacy.chains.llm import LLMChain
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
class SummarizerMixin(BaseModel):
|
||||
"""Mixin for summarizer."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
|
||||
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Conversation summarizer to chat memory."""
|
||||
|
||||
buffer: str = ""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chat_memory: BaseChatMessageHistory,
|
||||
*,
|
||||
summarize_step: int = 2,
|
||||
**kwargs: Any,
|
||||
) -> ConversationSummaryMemory:
|
||||
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
|
||||
for i in range(0, len(obj.chat_memory.messages), summarize_step):
|
||||
obj.buffer = obj.predict_new_summary(
|
||||
obj.chat_memory.messages[i : i + summarize_step], obj.buffer
|
||||
)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||
else:
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.buffer = self.predict_new_summary(
|
||||
self.chat_memory.messages[-2:], self.buffer
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.buffer = ""
|
||||
77
libs/core/langchain_core/legacy/memory/summary_buffer.py
Normal file
77
libs/core/langchain_core/legacy/memory/summary_buffer.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.legacy.memory.summary import SummarizerMixin
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
|
||||
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Buffer with summarizer for storing conversation memory."""
|
||||
|
||||
max_token_limit: int = 2000
|
||||
moving_summary_buffer: str = ""
|
||||
memory_key: str = "history"
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer = self.buffer
|
||||
if self.moving_summary_buffer != "":
|
||||
first_messages: List[BaseMessage] = [
|
||||
self.summary_message_cls(content=self.moving_summary_buffer)
|
||||
]
|
||||
buffer = first_messages + buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.prune()
|
||||
|
||||
def prune(self) -> None:
|
||||
"""Prune buffer if it exceeds max token limit"""
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
pruned_memory, self.moving_summary_buffer
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.moving_summary_buffer = ""
|
||||
58
libs/core/langchain_core/legacy/memory/token_buffer.py
Normal file
58
libs/core/langchain_core/legacy/memory/token_buffer.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class ConversationTokenBufferMemory(BaseChatMemory):
|
||||
"""Conversation chat memory with token limit."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "history"
|
||||
max_token_limit: int = 2000
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer. Pruned."""
|
||||
super().save_context(inputs, outputs)
|
||||
# Prune buffer if it exceeds max token limit
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
20
libs/core/langchain_core/legacy/memory/utils.py
Normal file
20
libs/core/langchain_core/legacy/memory/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
|
||||
"""
|
||||
Get the prompt input key.
|
||||
|
||||
Args:
|
||||
inputs: Dict[str, Any]
|
||||
memory_variables: List[str]
|
||||
|
||||
Returns:
|
||||
A prompt input key.
|
||||
"""
|
||||
# "stop" is a special key that can be passed as input but is not used to
|
||||
# format the prompt.
|
||||
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
|
||||
if len(prompt_input_keys) != 1:
|
||||
raise ValueError(f"One input key expected got {prompt_input_keys}")
|
||||
return prompt_input_keys[0]
|
||||
100
libs/core/langchain_core/legacy/memory/vectorstore.py
Normal file
100
libs/core/langchain_core/legacy/memory/vectorstore.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Class for a VectorStore-backed memory object."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.legacy.memory.chat_memory import BaseMemory
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
|
||||
class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""VectorStoreRetriever-backed memory."""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStoreRetriever object to connect to."""
|
||||
|
||||
memory_key: str = "history" #: :meta private:
|
||||
"""Key name to locate the memories in the result of load_memory_variables."""
|
||||
|
||||
input_key: Optional[str] = None
|
||||
"""Key name to index the inputs to load_memory_variables."""
|
||||
|
||||
return_docs: bool = False
|
||||
"""Whether or not to return the result of querying the database directly."""
|
||||
|
||||
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
|
||||
"""Input keys to exclude in addition to memory key when constructing the document"""
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The list of keys emitted from the load_memory_variables method."""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _documents_to_memory_variables(
|
||||
self, docs: List[Document]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
result: Union[List[Document], str]
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.invoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
async def aload_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = await self.retriever.ainvoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
def _form_documents(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> List[Document]:
|
||||
"""Format context from this conversation to buffer."""
|
||||
# Each document should only include the current turn, not the chat history
|
||||
exclude = set(self.exclude_input_keys)
|
||||
exclude.add(self.memory_key)
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
|
||||
texts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in list(filtered_inputs.items()) + list(outputs.items())
|
||||
]
|
||||
page_content = "\n".join(texts)
|
||||
return [Document(page_content=page_content)]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
await self.retriever.aadd_documents(documents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
@@ -3,6 +3,13 @@ from typing import Dict, Tuple
|
||||
# First value is the value that it is serialized as
|
||||
# Second value is the path to load it from
|
||||
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
("langchain", "chains", "llm", "LLMChain"): (
|
||||
"langchain_core",
|
||||
"legacy",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain",
|
||||
),
|
||||
("langchain", "schema", "messages", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
|
||||
@@ -1,83 +1,5 @@
|
||||
"""**Memory** maintains Chain state, incorporating context from past runs.
|
||||
from langchain_core.legacy.memory import BaseMemory
|
||||
|
||||
**Class hierarchy for Memory:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseMemory --> <name>Memory --> <name>Memory # Examples: BaseChatMemory -> MotorheadMemory
|
||||
|
||||
""" # noqa: E501
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
class BaseMemory(Serializable, ABC):
|
||||
"""Abstract base class for memory in Chains.
|
||||
|
||||
Memory refers to state in Chains. Memory can be used to store information about
|
||||
past executions of a Chain and inject that information into the inputs of
|
||||
future executions of the Chain. For example, for conversational Chains Memory
|
||||
can be used to store conversations and automatically add them to future model
|
||||
prompts so that the model has the necessary context to respond coherently to
|
||||
the latest input.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The string keys this memory class will add to chain inputs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
return await run_in_executor(None, self.load_memory_variables, inputs)
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
await run_in_executor(None, self.save_context, inputs, outputs)
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await run_in_executor(None, self.clear)
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
]
|
||||
|
||||
35
libs/core/tests/unit_tests/test_memory.py
Normal file
35
libs/core/tests/unit_tests/test_memory.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.legacy.memory import CombinedMemory, ConversationBufferMemory
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def example_memory() -> List[ConversationBufferMemory]:
|
||||
example_1 = ConversationBufferMemory(memory_key="foo")
|
||||
example_2 = ConversationBufferMemory(memory_key="bar")
|
||||
example_3 = ConversationBufferMemory(memory_key="bar")
|
||||
return [example_1, example_2, example_3]
|
||||
|
||||
|
||||
def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None:
|
||||
"""Test basic functionality of methods exposed by class"""
|
||||
combined_memory = CombinedMemory(memories=[example_memory[0], example_memory[1]])
|
||||
assert combined_memory.memory_variables == ["foo", "bar"]
|
||||
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
|
||||
combined_memory.save_context(
|
||||
{"input": "Hello there"}, {"output": "Hello, how can I help you?"}
|
||||
)
|
||||
assert combined_memory.load_memory_variables({}) == {
|
||||
"foo": "Human: Hello there\nAI: Hello, how can I help you?",
|
||||
"bar": "Human: Hello there\nAI: Hello, how can I help you?",
|
||||
}
|
||||
combined_memory.clear()
|
||||
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
|
||||
|
||||
|
||||
def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None:
|
||||
"""Test raising error when repeated memory variables found"""
|
||||
with pytest.raises(ValueError):
|
||||
CombinedMemory(memories=[example_memory[1], example_memory[2]])
|
||||
@@ -48,7 +48,7 @@ _module_lookup = {
|
||||
"GraphSparqlQAChain": "langchain.chains.graph_qa.sparql",
|
||||
"create_history_aware_retriever": "langchain.chains.history_aware_retriever",
|
||||
"HypotheticalDocumentEmbedder": "langchain.chains.hyde.base",
|
||||
"LLMChain": "langchain.chains.llm",
|
||||
"LLMChain": "langchain_core.legacy.chains.llm",
|
||||
"LLMCheckerChain": "langchain.chains.llm_checker.base",
|
||||
"LLMMathChain": "langchain.chains.llm_math.base",
|
||||
"LLMRequestsChain": "langchain.chains.llm_requests",
|
||||
|
||||
@@ -1,732 +1,6 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, Union, cast
|
||||
from langchain_core.legacy.chains.base import Chain, _get_verbosity
|
||||
|
||||
import yaml
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.outputs import RunInfo
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableSerializable,
|
||||
ensure_config,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.runnables.utils import create_model
|
||||
|
||||
from langchain.schema import RUN_KEY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
from langchain.globals import get_verbose
|
||||
|
||||
return get_verbose()
|
||||
|
||||
|
||||
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
"""Abstract base class for creating structured sequences of calls to components.
|
||||
|
||||
Chains should be used to encode a sequence of calls to components like
|
||||
models, document retrievers, other chains, etc., and provide a simple interface
|
||||
to this sequence.
|
||||
|
||||
The Chain interface makes it easy to create apps that are:
|
||||
- Stateful: add Memory to any Chain to give it state,
|
||||
- Observable: pass Callbacks to a Chain to execute additional functionality,
|
||||
like logging, outside the main sequence of component calls,
|
||||
- Composable: the Chain API is flexible enough that it is easy to combine
|
||||
Chains with other components, including other Chains.
|
||||
|
||||
The main methods exposed by chains are:
|
||||
- `__call__`: Chains are callable. The `__call__` method is the primary way to
|
||||
execute a Chain. This takes inputs as a dictionary and returns a
|
||||
dictionary output.
|
||||
- `run`: A convenience method that takes inputs as args/kwargs and returns the
|
||||
output as a string or object. This method can only be used for a subset of
|
||||
chains and cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
Memory is a class that gets called at the start
|
||||
and at the end of every chain. At the start, memory loads variables and passes
|
||||
them along in the chain. At the end, it saves any returned variables.
|
||||
There are many different types of memory - please see memory docs
|
||||
for the full catalog."""
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
"""Optional list of callback handlers (or callback manager). Defaults to None.
|
||||
Callback handlers are called throughout the lifecycle of a call to a chain,
|
||||
starting with on_chain_start, ending with on_chain_end or on_chain_error.
|
||||
Each custom chain can optionally call additional callback methods, see Callback docs
|
||||
for full details."""
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
|
||||
will be printed to the console. Defaults to the global `verbose` value,
|
||||
accessible via `langchain.globals.get_verbose()`."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the chain. Defaults to None.
|
||||
These tags will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None.
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED] Use `callbacks` instead."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainInput", **{k: (Any, None) for k in self.input_keys}
|
||||
)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = self.prep_inputs(input)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = ensure_config(config)
|
||||
callbacks = config.get("callbacks")
|
||||
tags = config.get("tags")
|
||||
metadata = config.get("metadata")
|
||||
run_name = config.get("run_name") or self.get_name()
|
||||
include_run_info = kwargs.get("include_run_info", False)
|
||||
return_only_outputs = kwargs.get("return_only_outputs", False)
|
||||
|
||||
inputs = await self.aprep_inputs(input)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
if values.get("callbacks") is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both callback_manager and callbacks. "
|
||||
"callback_manager is deprecated, callbacks is the preferred "
|
||||
"parameter to pass in."
|
||||
)
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""Set the chain verbosity.
|
||||
|
||||
Defaults to the global setting if not specified by the user.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Keys expected to be in the chain input."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Keys expected to be in the chain output."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.__call__`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
This is a private method that is not user-facing. It is only called within
|
||||
`Chain.acall`, which is the user-facing wrapper method that handles
|
||||
callbacks configuration and some input/output processing.
|
||||
|
||||
Args:
|
||||
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
|
||||
specified in `Chain.input_keys`, including any inputs added by memory.
|
||||
run_manager: The callbacks manager that contains the callback handlers for
|
||||
this run of the chain.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
return await run_in_executor(
|
||||
None, self._call, inputs, run_manager.get_sync() if run_manager else None
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
|
||||
return self.invoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def acall(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asynchronously execute the chain.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
return_only_outputs: Whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
|
||||
Returns:
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
config = {
|
||||
"callbacks": callbacks,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_name": run_name,
|
||||
}
|
||||
return await self.ainvoke(
|
||||
inputs,
|
||||
cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
|
||||
return_only_outputs=return_only_outputs,
|
||||
include_run_info=include_run_info,
|
||||
)
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prepare chain outputs, and save info about this run to memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of chain inputs, including any inputs added by chain
|
||||
memory.
|
||||
outputs: Dictionary of initial chain outputs.
|
||||
return_only_outputs: Whether to only return the chain outputs. If False,
|
||||
inputs are also added to the final outputs.
|
||||
|
||||
Returns:
|
||||
A dict of the final chain outputs.
|
||||
"""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
|
||||
"""Prepare chain inputs, including adding inputs from memory.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of raw inputs, or single input if chain expects
|
||||
only one param. Should contain all inputs specified in
|
||||
`Chain.input_keys` except for inputs that will be set by the chain's
|
||||
memory.
|
||||
|
||||
Returns:
|
||||
A dictionary of all inputs, including those added by the chain's memory.
|
||||
"""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = await self.memory.aload_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
return self.output_keys[0]
|
||||
|
||||
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Convenience method for executing chain.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this
|
||||
method expects inputs to be passed directly in as positional arguments or
|
||||
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
|
||||
with all the inputs
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
chain.run("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
chain.run(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
# Run at start to make sure this is possible/defined
|
||||
_output_key = self._run_output_key
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Convenience method for executing chain.
|
||||
|
||||
The main difference between this method and `Chain.__call__` is that this
|
||||
method expects inputs to be passed directly in as positional arguments or
|
||||
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
|
||||
with all the inputs
|
||||
|
||||
|
||||
Args:
|
||||
*args: If the chain expects a single input, it can be passed in as the
|
||||
sole positional argument.
|
||||
callbacks: Callbacks to use for this chain run. These will be called in
|
||||
addition to callbacks passed to the chain during construction, but only
|
||||
these runtime callbacks will propagate to calls to other objects.
|
||||
tags: List of string tags to pass to all callbacks. These will be passed in
|
||||
addition to tags passed to the chain during construction, but only
|
||||
these runtime tags will propagate to calls to other objects.
|
||||
**kwargs: If the chain expects multiple inputs, they can be passed in
|
||||
directly as keyword arguments.
|
||||
|
||||
Returns:
|
||||
The chain output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Suppose we have a single-input chain that takes a 'question' string:
|
||||
await chain.arun("What's the temperature in Boise, Idaho?")
|
||||
# -> "The temperature in Boise is..."
|
||||
|
||||
# Suppose we have a multi-input chain that takes a 'question' string
|
||||
# and 'context' string:
|
||||
question = "What's the temperature in Boise, Idaho?"
|
||||
context = "Weather report for Boise, Idaho on 07/03/23..."
|
||||
await chain.arun(question=question, context=context)
|
||||
# -> "The temperature in Boise is..."
|
||||
"""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
elif args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (
|
||||
await self.acall(
|
||||
args[0], callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (
|
||||
await self.acall(
|
||||
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Dictionary representation of chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
|
||||
method.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the chain.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.dict(exclude_unset=True)
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
_dict = super().dict(**kwargs)
|
||||
try:
|
||||
_dict["_type"] = self._chain_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Expects `Chain._chain_type` property to be implemented and for memory to be
|
||||
null.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
if "_type" not in chain_dict:
|
||||
raise NotImplementedError(f"Chain {self} does not support saving.")
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix.endswith((".yaml", ".yml")):
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
__all__ = [
|
||||
"Chain",
|
||||
"_get_verbosity",
|
||||
]
|
||||
|
||||
@@ -1,423 +1,5 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from __future__ import annotations
|
||||
from langchain_core.legacy.chains import LLMChain
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.language_models import (
|
||||
BaseLanguageModel,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableBranch,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain_core.runnables.configurable import DynamicRunnable
|
||||
from langchain_core.utils.input import get_colored_text
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.1.17",
|
||||
alternative="RunnableSequence, e.g., `prompt | llm`",
|
||||
removal="0.3.0",
|
||||
)
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
This class is deprecated. See below for an example implementation using
|
||||
LangChain runnables:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import OpenAI
|
||||
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = OpenAI()
|
||||
chain = prompt | llm
|
||||
|
||||
chain.invoke("your adjective here")
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["adjective"], template=prompt_template
|
||||
)
|
||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
"""Prompt object to use."""
|
||||
llm: Union[
|
||||
Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
|
||||
]
|
||||
"""Language model to call."""
|
||||
output_key: str = "text" #: :meta private:
|
||||
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
|
||||
"""Output parser to use.
|
||||
Defaults to one that takes the most likely string but does not change it
|
||||
otherwise."""
|
||||
return_final_only: bool = True
|
||||
"""Whether to return only the final parsed result. Defaults to True.
|
||||
If false, will return a bunch of extra information about the generation."""
|
||||
llm_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if self.return_final_only:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "full_generation"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
callbacks = run_manager.get_child() if run_manager else None
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=callbacks,
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
else:
|
||||
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
|
||||
cast(List, prompts), {"callbacks": callbacks}
|
||||
)
|
||||
generations: List[List[Generation]] = []
|
||||
for res in results:
|
||||
if isinstance(res, BaseMessage):
|
||||
generations.append([ChatGeneration(message=res)])
|
||||
else:
|
||||
generations.append([Generation(text=res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
|
||||
callbacks = run_manager.get_child() if run_manager else None
|
||||
if isinstance(self.llm, BaseLanguageModel):
|
||||
return await self.llm.agenerate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=callbacks,
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
else:
|
||||
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
|
||||
cast(List, prompts), {"callbacks": callbacks}
|
||||
)
|
||||
generations: List[List[Generation]] = []
|
||||
for res in results:
|
||||
if isinstance(res, BaseMessage):
|
||||
generations.append([ChatGeneration(message=res)])
|
||||
else:
|
||||
generations.append([Generation(text=res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if len(input_list) == 0:
|
||||
return [], stop
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
async def aprep_prompts(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Tuple[List[PromptValue], Optional[List[str]]]:
|
||||
"""Prepare prompts from inputs."""
|
||||
stop = None
|
||||
if len(input_list) == 0:
|
||||
return [], stop
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager:
|
||||
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, stop
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = self.generate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
async def aapply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input_list": input_list},
|
||||
)
|
||||
try:
|
||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
outputs = self.create_outputs(response)
|
||||
await run_manager.on_chain_end({"outputs": outputs})
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def _run_output_key(self) -> str:
|
||||
return self.output_key
|
||||
|
||||
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
|
||||
"""Create outputs from response."""
|
||||
result = [
|
||||
# Get the text of the top generated string.
|
||||
{
|
||||
self.output_key: self.output_parser.parse_result(generation),
|
||||
"full_generation": generation,
|
||||
}
|
||||
for generation in llm_result.generations
|
||||
]
|
||||
if self.return_final_only:
|
||||
result = [{self.output_key: r[self.output_key]} for r in result]
|
||||
return result
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
response = await self.agenerate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
|
||||
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
"""Call predict and then parse the results."""
|
||||
warnings.warn(
|
||||
"The predict_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = self.predict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
async def apredict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call apredict and then parse the results."""
|
||||
warnings.warn(
|
||||
"The apredict_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = await self.apredict(callbacks=callbacks, **kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
warnings.warn(
|
||||
"The apply_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = self.apply(input_list, callbacks=callbacks)
|
||||
return self._parse_generation(result)
|
||||
|
||||
def _parse_generation(
|
||||
self, generation: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
return [
|
||||
self.prompt.output_parser.parse(res[self.output_key])
|
||||
for res in generation
|
||||
]
|
||||
else:
|
||||
return generation
|
||||
|
||||
async def aapply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
warnings.warn(
|
||||
"The aapply_and_parse method is deprecated, "
|
||||
"instead pass an output parser directly to LLMChain."
|
||||
)
|
||||
result = await self.aapply(input_list, callbacks=callbacks)
|
||||
return self._parse_generation(result)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_chain"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
|
||||
"""Create LLMChain from LLM and template."""
|
||||
prompt_template = PromptTemplate.from_template(template)
|
||||
return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
def _get_num_tokens(self, text: str) -> int:
|
||||
return _get_language_model(self.llm).get_num_tokens(text)
|
||||
|
||||
|
||||
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
|
||||
if isinstance(llm_like, BaseLanguageModel):
|
||||
return llm_like
|
||||
elif isinstance(llm_like, RunnableBinding):
|
||||
return _get_language_model(llm_like.bound)
|
||||
elif isinstance(llm_like, RunnableWithFallbacks):
|
||||
return _get_language_model(llm_like.runnable)
|
||||
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
|
||||
return _get_language_model(llm_like.default)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to extract BaseLanguageModel from llm_like object of type "
|
||||
f"{type(llm_like)}"
|
||||
)
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
]
|
||||
|
||||
@@ -45,29 +45,28 @@ from langchain_community.chat_message_histories import (
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
|
||||
from langchain.memory.buffer import (
|
||||
ConversationBufferMemory,
|
||||
ConversationStringBufferMemory,
|
||||
)
|
||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain.memory.combined import CombinedMemory
|
||||
from langchain.memory.entity import (
|
||||
ConversationEntityMemory,
|
||||
InMemoryEntityStore,
|
||||
from langchain_community.memory import (
|
||||
ConversationKGMemory,
|
||||
MotorheadMemory,
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
ZepMemory,
|
||||
)
|
||||
from langchain_core.legacy.memory import (
|
||||
CombinedMemory,
|
||||
ConversationBufferMemory,
|
||||
ConversationBufferWindowMemory,
|
||||
ConversationEntityMemory,
|
||||
ConversationStringBufferMemory,
|
||||
ConversationSummaryBufferMemory,
|
||||
ConversationSummaryMemory,
|
||||
ConversationTokenBufferMemory,
|
||||
InMemoryEntityStore,
|
||||
ReadOnlySharedMemory,
|
||||
SimpleMemory,
|
||||
VectorStoreRetrieverMemory,
|
||||
)
|
||||
from langchain.memory.kg import ConversationKGMemory
|
||||
from langchain.memory.motorhead_memory import MotorheadMemory
|
||||
from langchain.memory.readonly import ReadOnlySharedMemory
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
from langchain.memory.summary import ConversationSummaryMemory
|
||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
from langchain.memory.zep_memory import ZepMemory
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
|
||||
@@ -1,136 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain_core.legacy.memory import (
|
||||
ConversationBufferMemory,
|
||||
ConversationStringBufferMemory,
|
||||
)
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
class ConversationBufferMemory(BaseChatMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
async def abuffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return (
|
||||
await self.abuffer_as_messages()
|
||||
if self.return_messages
|
||||
else await self.abuffer_as_str()
|
||||
)
|
||||
|
||||
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
return self._buffer_as_str(self.chat_memory.messages)
|
||||
|
||||
async def abuffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
messages = await self.chat_memory.aget_messages()
|
||||
return self._buffer_as_str(messages)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
async def abuffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return await self.chat_memory.aget_messages()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.abuffer()
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
|
||||
class ConversationStringBufferMemory(BaseMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
"""Prefix to use for AI generated responses."""
|
||||
buffer: str = ""
|
||||
output_key: Optional[str] = None
|
||||
input_key: Optional[str] = None
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@root_validator()
|
||||
def validate_chains(cls, values: Dict) -> Dict:
|
||||
"""Validate that return messages is not True."""
|
||||
if values.get("return_messages", False):
|
||||
raise ValueError(
|
||||
"return_messages must be False for ConversationStringBufferMemory"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return self.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
output_key = list(outputs.keys())[0]
|
||||
else:
|
||||
output_key = self.output_key
|
||||
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
|
||||
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
||||
self.buffer += "\n" + "\n".join([human, ai])
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
return self.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.buffer = ""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
self.clear()
|
||||
__all__ = ["ConversationBufferMemory", "ConversationStringBufferMemory"]
|
||||
|
||||
@@ -1,47 +1,3 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from langchain_core.legacy.memory import ConversationBufferWindowMemory
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""Buffer for storing conversation memory inside a limited size window."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history" #: :meta private:
|
||||
k: int = 5
|
||||
"""Number of messages to store in buffer."""
|
||||
|
||||
@property
|
||||
def buffer(self) -> Union[str, List[BaseMessage]]:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
__all__ = ["ConversationBufferWindowMemory"]
|
||||
|
||||
@@ -1,75 +1,5 @@
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
|
||||
|
||||
from langchain_core.chat_history import (
|
||||
BaseChatMessageHistory,
|
||||
InMemoryChatMessageHistory,
|
||||
)
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
class BaseChatMemory(BaseMemory, ABC):
|
||||
"""Abstract base class for chat memory."""
|
||||
|
||||
chat_memory: BaseChatMessageHistory = Field(
|
||||
default_factory=InMemoryChatMessageHistory
|
||||
)
|
||||
output_key: Optional[str] = None
|
||||
input_key: Optional[str] = None
|
||||
return_messages: bool = False
|
||||
|
||||
def _get_input_output(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) == 1:
|
||||
output_key = list(outputs.keys())[0]
|
||||
elif "output" in outputs:
|
||||
output_key = "output"
|
||||
warnings.warn(
|
||||
f"'{self.__class__.__name__}' got multiple output keys:"
|
||||
f" {outputs.keys()}. The default 'output' key is being used."
|
||||
f" If this is not desired, please manually set 'output_key'."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got multiple output keys: {outputs.keys()}, cannot "
|
||||
f"determine which to store in memory. Please set the "
|
||||
f"'output_key' explicitly."
|
||||
)
|
||||
else:
|
||||
output_key = self.output_key
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
await self.chat_memory.aadd_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await self.chat_memory.aclear()
|
||||
__all__ = [
|
||||
"BaseChatMemory",
|
||||
]
|
||||
|
||||
@@ -1,82 +1,3 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Set
|
||||
from langchain_core.legacy.memory import CombinedMemory
|
||||
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.pydantic_v1 import validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
class CombinedMemory(BaseMemory):
|
||||
"""Combining multiple memories' data together."""
|
||||
|
||||
memories: List[BaseMemory]
|
||||
"""For tracking all the memories that should be accessed."""
|
||||
|
||||
@validator("memories")
|
||||
def check_repeated_memory_variable(
|
||||
cls, value: List[BaseMemory]
|
||||
) -> List[BaseMemory]:
|
||||
all_variables: Set[str] = set()
|
||||
for val in value:
|
||||
overlap = all_variables.intersection(val.memory_variables)
|
||||
if overlap:
|
||||
raise ValueError(
|
||||
f"The same variables {overlap} are found in multiple"
|
||||
"memory object, which is not allowed by CombinedMemory."
|
||||
)
|
||||
all_variables |= set(val.memory_variables)
|
||||
|
||||
return value
|
||||
|
||||
@validator("memories")
|
||||
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
|
||||
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
||||
for val in value:
|
||||
if isinstance(val, BaseChatMemory):
|
||||
if val.input_key is None:
|
||||
warnings.warn(
|
||||
"When using CombinedMemory, "
|
||||
"input keys should be so the input is known. "
|
||||
f" Was not set on {val}"
|
||||
)
|
||||
return value
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""All the memory variables that this instance provides."""
|
||||
"""Collected from the all the linked memories."""
|
||||
|
||||
memory_variables = []
|
||||
|
||||
for memory in self.memories:
|
||||
memory_variables.extend(memory.memory_variables)
|
||||
|
||||
return memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load all vars from sub-memories."""
|
||||
memory_data: Dict[str, Any] = {}
|
||||
|
||||
# Collect vars from all sub-memories
|
||||
for memory in self.memories:
|
||||
data = memory.load_memory_variables(inputs)
|
||||
for key, value in data.items():
|
||||
if key in memory_data:
|
||||
raise ValueError(
|
||||
f"The variable {key} is repeated in the CombinedMemory."
|
||||
)
|
||||
memory_data[key] = value
|
||||
|
||||
return memory_data
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this session for every memory."""
|
||||
# Save context for all sub-memories
|
||||
for memory in self.memories:
|
||||
memory.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear context from this session for every memory."""
|
||||
for memory in self.memories:
|
||||
memory.clear()
|
||||
__all__ = ["CombinedMemory"]
|
||||
|
||||
@@ -1,483 +1,19 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import islice
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
ENTITY_SUMMARIZATION_PROMPT,
|
||||
from langchain_community.memory.entity import (
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEntityStore(BaseModel, ABC):
|
||||
"""Abstract base class for Entity store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""Get entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
"""Set entity value in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if entity exists in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Delete all entities from store."""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryEntityStore(BaseEntityStore):
|
||||
"""In-memory Entity store."""
|
||||
|
||||
store: Dict[str, Optional[str]] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
del self.store[key]
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.store
|
||||
|
||||
def clear(self) -> None:
|
||||
return self.store.clear()
|
||||
|
||||
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
|
||||
"""Upstash Redis backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import upstash_redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception:
|
||||
logger.error("Upstash Redis instance could not be initiated.")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
def scan_and_delete(cursor: int) -> int:
|
||||
cursor, keys_to_delete = self.redis_client.scan(
|
||||
cursor, f"{self.full_key_prefix}:*"
|
||||
)
|
||||
self.redis_client.delete(*keys_to_delete)
|
||||
return cursor
|
||||
|
||||
cursor = scan_and_delete(0)
|
||||
while cursor != 0:
|
||||
scan_and_delete(cursor)
|
||||
|
||||
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
redis_client: Any
|
||||
session_id: str = "default"
|
||||
key_prefix: str = "memory_store"
|
||||
ttl: Optional[int] = 60 * 60 * 24
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
# iterate a list in batches of size batch_size
|
||||
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
||||
iterator = iter(iterable)
|
||||
while batch := list(islice(iterator, batch_size)):
|
||||
yield batch
|
||||
|
||||
for keybatch in batched(
|
||||
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
|
||||
):
|
||||
self.redis_client.delete(*keybatch)
|
||||
|
||||
|
||||
class SQLiteEntityStore(BaseEntityStore):
|
||||
"""SQLite-backed Entity store"""
|
||||
|
||||
session_id: str = "default"
|
||||
table_name: str = "memory_store"
|
||||
conn: Any = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
db_file: str = "entities.db",
|
||||
table_name: str = "memory_store",
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sqlite3 python package. "
|
||||
"Please install it with `pip install sqlite3`."
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conn = sqlite3.connect(db_file)
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
@property
|
||||
def full_table_name(self) -> str:
|
||||
return f"{self.table_name}_{self.session_id}"
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(create_table_query)
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
query = f"""
|
||||
SELECT value
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
if result is not None:
|
||||
value = result[0]
|
||||
return value
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
query = f"""
|
||||
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
|
||||
VALUES (?, ?)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key, value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key,))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
query = f"""
|
||||
SELECT 1
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
def clear(self) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query)
|
||||
|
||||
|
||||
class ConversationEntityMemory(BaseChatMemory):
|
||||
"""Entity extractor & summarizer memory.
|
||||
|
||||
Extracts named entities from the recent chat history and generates summaries.
|
||||
With a swappable entity store, persisting entities across conversations.
|
||||
Defaults to an in-memory entity store, and can be swapped out for a Redis,
|
||||
SQLite, or other entity store.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
||||
|
||||
# Cache of recently detected entity names, if any
|
||||
# It is updated when load_memory_variables is called:
|
||||
entity_cache: List[str] = []
|
||||
|
||||
# Number of recent message pairs to consider when updating entities:
|
||||
k: int = 3
|
||||
|
||||
chat_history_key: str = "history"
|
||||
|
||||
# Store to manage entity-related data:
|
||||
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""Access chat memory messages."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return ["entities", self.chat_history_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns chat history and all generated entities with summaries if available,
|
||||
and updates or clears the recent entity cache.
|
||||
|
||||
New entity name can be found when calling this method, before the entity
|
||||
summaries are generated, so the entity cache values may be empty if no entity
|
||||
descriptions are generated yet.
|
||||
"""
|
||||
|
||||
# Create an LLMChain for predicting entity names from the recent chat history:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
# Generates a comma-separated list of named entities,
|
||||
# e.g. "Jane, White House, UFO"
|
||||
# or "NONE" if no named entities are extracted:
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=inputs[prompt_input_key],
|
||||
)
|
||||
|
||||
# If no named entities are extracted, assigns an empty list.
|
||||
if output.strip() == "NONE":
|
||||
entities = []
|
||||
else:
|
||||
# Make a list of the extracted entities:
|
||||
entities = [w.strip() for w in output.split(",")]
|
||||
|
||||
# Make a dictionary of entities with summary if exists:
|
||||
entity_summaries = {}
|
||||
|
||||
for entity in entities:
|
||||
entity_summaries[entity] = self.entity_store.get(entity, "")
|
||||
|
||||
# Replaces the entity name cache with the most recently discussed entities,
|
||||
# or if no entities were extracted, clears the cache:
|
||||
self.entity_cache = entities
|
||||
|
||||
# Should we return as message objects or as a string?
|
||||
if self.return_messages:
|
||||
# Get last `k` pair of chat messages:
|
||||
buffer: Any = self.buffer[-self.k * 2 :]
|
||||
else:
|
||||
# Reuse the string we made earlier:
|
||||
buffer = buffer_string
|
||||
|
||||
return {
|
||||
self.chat_history_key: buffer,
|
||||
"entities": entity_summaries,
|
||||
}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""
|
||||
Save context from this conversation history to the entity store.
|
||||
|
||||
Generates a summary for each entity in the entity cache by prompting
|
||||
the model, and saves these summaries to the entity store.
|
||||
"""
|
||||
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
input_data = inputs[prompt_input_key]
|
||||
|
||||
# Create an LLMChain for predicting entity summarization from the context
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
|
||||
|
||||
# Generate new summaries for entities and save them in the entity store
|
||||
for entity in self.entity_cache:
|
||||
# Get existing summary if it exists
|
||||
existing_summary = self.entity_store.get(entity, "")
|
||||
output = chain.predict(
|
||||
summary=existing_summary,
|
||||
entity=entity,
|
||||
history=buffer_string,
|
||||
input=input_data,
|
||||
)
|
||||
# Save the updated summary to the entity store
|
||||
self.entity_store.set(entity, output.strip())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
self.entity_cache.clear()
|
||||
self.entity_store.clear()
|
||||
from langchain_core.legacy.memory.entity import (
|
||||
BaseEntityStore,
|
||||
ConversationEntityMemory,
|
||||
InMemoryEntityStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEntityStore",
|
||||
"InMemoryEntityStore",
|
||||
"UpstashRedisEntityStore",
|
||||
"RedisEntityStore",
|
||||
"SQLiteEntityStore",
|
||||
"ConversationEntityMemory",
|
||||
]
|
||||
|
||||
@@ -1,133 +1,5 @@
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
from langchain_community.memory import ConversationKGMemory
|
||||
|
||||
from langchain_community.graphs import NetworkxEntityGraph
|
||||
from langchain_community.graphs.networkx_graph import (
|
||||
KnowledgeTriple,
|
||||
get_entities,
|
||||
parse_triples,
|
||||
)
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
class ConversationKGMemory(BaseChatMemory):
|
||||
"""Knowledge graph conversation memory.
|
||||
|
||||
Integrates with external knowledge graph to store and retrieve
|
||||
information about knowledge triples in the conversation.
|
||||
"""
|
||||
|
||||
k: int = 2
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
|
||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
llm: BaseLanguageModel
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
"""Number of previous utterances to include in the context."""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
entities = self._get_current_entities(inputs)
|
||||
|
||||
summary_strings = []
|
||||
for entity in entities:
|
||||
knowledge = self.kg.get_entity_knowledge(entity)
|
||||
if knowledge:
|
||||
summary = f"On {entity}: {'. '.join(knowledge)}."
|
||||
summary_strings.append(summary)
|
||||
context: Union[str, List]
|
||||
if not summary_strings:
|
||||
context = [] if self.return_messages else ""
|
||||
elif self.return_messages:
|
||||
context = [
|
||||
self.summary_message_cls(content=text) for text in summary_strings
|
||||
]
|
||||
else:
|
||||
context = "\n".join(summary_strings)
|
||||
|
||||
return {self.memory_key: context}
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
|
||||
"""Get the output key for the prompt."""
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
return list(outputs.keys())[0]
|
||||
return self.output_key
|
||||
|
||||
def get_current_entities(self, input_string: str) -> List[str]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
)
|
||||
return get_entities(output)
|
||||
|
||||
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
|
||||
"""Get the current entities in the conversation."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
return self.get_current_entities(inputs[prompt_input_key])
|
||||
|
||||
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
|
||||
buffer_string = get_buffer_string(
|
||||
self.chat_memory.messages[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=input_string,
|
||||
verbose=True,
|
||||
)
|
||||
knowledge = parse_triples(output)
|
||||
return knowledge
|
||||
|
||||
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Get and update knowledge graph from the conversation history."""
|
||||
prompt_input_key = self._get_prompt_input_key(inputs)
|
||||
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
|
||||
for triple in knowledge:
|
||||
self.kg.add_triple(triple)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self._get_and_update_kg(inputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.kg.clear()
|
||||
__all__ = [
|
||||
"ConversationKGMemory",
|
||||
]
|
||||
|
||||
@@ -1,94 +1,5 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain_community.memory import MotorheadMemory
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import get_buffer_string
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
|
||||
# LOCAL_URL = "http://localhost:8080"
|
||||
|
||||
|
||||
class MotorheadMemory(BaseChatMemory):
|
||||
"""Chat message memory backed by Motorhead service."""
|
||||
|
||||
url: str = MANAGED_URL
|
||||
timeout: int = 3000
|
||||
memory_key: str = "history"
|
||||
session_id: str
|
||||
context: Optional[str] = None
|
||||
|
||||
# Managed Params
|
||||
api_key: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
|
||||
def __get_headers(self) -> Dict[str, str]:
|
||||
is_managed = self.url == MANAGED_URL
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if is_managed and not (self.api_key and self.client_id):
|
||||
raise ValueError(
|
||||
"""
|
||||
You must provide an API key or a client ID to use the managed
|
||||
version of Motorhead. Visit https://getmetal.io for more information.
|
||||
"""
|
||||
)
|
||||
|
||||
if is_managed and self.api_key and self.client_id:
|
||||
headers["x-metal-api-key"] = self.api_key
|
||||
headers["x-metal-client-id"] = self.client_id
|
||||
|
||||
return headers
|
||||
|
||||
async def init(self) -> None:
|
||||
res = requests.get(
|
||||
f"{self.url}/sessions/{self.session_id}/memory",
|
||||
timeout=self.timeout,
|
||||
headers=self.__get_headers(),
|
||||
)
|
||||
res_data = res.json()
|
||||
res_data = res_data.get("data", res_data) # Handle Managed Version
|
||||
|
||||
messages = res_data.get("messages", [])
|
||||
context = res_data.get("context", "NONE")
|
||||
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "AI":
|
||||
self.chat_memory.add_ai_message(message["content"])
|
||||
else:
|
||||
self.chat_memory.add_user_message(message["content"])
|
||||
|
||||
if context and context != "NONE":
|
||||
self.context = context
|
||||
|
||||
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.return_messages:
|
||||
return {self.memory_key: self.chat_memory.messages}
|
||||
else:
|
||||
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return [self.memory_key]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
requests.post(
|
||||
f"{self.url}/sessions/{self.session_id}/memory",
|
||||
timeout=self.timeout,
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "Human", "content": f"{input_str}"},
|
||||
{"role": "AI", "content": f"{output_str}"},
|
||||
]
|
||||
},
|
||||
headers=self.__get_headers(),
|
||||
)
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
def delete_session(self) -> None:
|
||||
"""Delete a session"""
|
||||
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
|
||||
__all__ = [
|
||||
"MotorheadMemory",
|
||||
]
|
||||
|
||||
@@ -1,165 +1,15 @@
|
||||
# flake8: noqa
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI.
|
||||
|
||||
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
||||
|
||||
You are constantly learning and improving, and your capabilities are constantly evolving. You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. You have access to some personalized information provided by the human in the Context section below. Additionally, you are able to generate your own text based on the input you receive, allowing you to engage in discussions and provide explanations and descriptions on a wide range of topics.
|
||||
|
||||
Overall, you are a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether the human needs help with a specific question or just wants to have a conversation about a particular topic, you are here to assist.
|
||||
|
||||
Context:
|
||||
{entities}
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Last line:
|
||||
Human: {input}
|
||||
You:"""
|
||||
|
||||
ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate(
|
||||
input_variables=["entities", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE,
|
||||
from langchain_core.legacy.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
ENTITY_MEMORY_CONVERSATION_TEMPLATE,
|
||||
ENTITY_SUMMARIZATION_PROMPT,
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
|
||||
SUMMARY_PROMPT,
|
||||
)
|
||||
|
||||
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
|
||||
|
||||
EXAMPLE
|
||||
Current summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
|
||||
|
||||
New lines of conversation:
|
||||
Human: Why do you think artificial intelligence is a force for good?
|
||||
AI: Because artificial intelligence will help humans reach their full potential.
|
||||
|
||||
New summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
|
||||
END OF EXAMPLE
|
||||
|
||||
Current summary:
|
||||
{summary}
|
||||
|
||||
New lines of conversation:
|
||||
{new_lines}
|
||||
|
||||
New summary:"""
|
||||
SUMMARY_PROMPT = PromptTemplate(
|
||||
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
|
||||
|
||||
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
|
||||
|
||||
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
|
||||
Output: Langchain
|
||||
END OF EXAMPLE
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
|
||||
Output: Langchain, Person #2
|
||||
END OF EXAMPLE
|
||||
|
||||
Conversation history (for reference only):
|
||||
{history}
|
||||
Last line of conversation (for extraction):
|
||||
Human: {input}
|
||||
|
||||
Output:"""
|
||||
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
|
||||
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
|
||||
|
||||
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
|
||||
|
||||
Full conversation history (for context):
|
||||
{history}
|
||||
|
||||
Entity to summarize:
|
||||
{entity}
|
||||
|
||||
Existing summary of {entity}:
|
||||
{summary}
|
||||
|
||||
Last line of conversation:
|
||||
Human: {input}
|
||||
Updated summary:"""
|
||||
|
||||
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
|
||||
input_variables=["entity", "summary", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
KG_TRIPLE_DELIMITER = "<|>"
|
||||
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
|
||||
"You are a networked intelligence helping a human track knowledge triples"
|
||||
" about all relevant people, things, concepts, etc. and integrating"
|
||||
" them with your knowledge stored within your weights"
|
||||
" as well as that stored in a knowledge graph."
|
||||
" Extract all of the knowledge triples from the last line of conversation."
|
||||
" A knowledge triple is a clause that contains a subject, a predicate,"
|
||||
" and an object. The subject is the entity being described,"
|
||||
" the predicate is the property of the subject that is being"
|
||||
" described, and the object is the value of the property.\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Did you hear aliens landed in Area 51?\n"
|
||||
"AI: No, I didn't hear that. What do you know about Area 51?\n"
|
||||
"Person #1: It's a secret military base in Nevada.\n"
|
||||
"AI: What do you know about Nevada?\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
|
||||
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
|
||||
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Hello.\n"
|
||||
"AI: Hi! How are you?\n"
|
||||
"Person #1: I'm good. How are you?\n"
|
||||
"AI: I'm good too.\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: I'm going to the store.\n\n"
|
||||
"Output: NONE\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: What do you know about Descartes?\n"
|
||||
"AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n"
|
||||
"Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n"
|
||||
"AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
|
||||
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"Conversation history (for reference only):\n"
|
||||
"{history}"
|
||||
"\nLast line of conversation (for extraction):\n"
|
||||
"Human: {input}\n\n"
|
||||
"Output:"
|
||||
)
|
||||
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
|
||||
)
|
||||
__all__ = [
|
||||
"ENTITY_SUMMARIZATION_PROMPT",
|
||||
"ENTITY_EXTRACTION_PROMPT",
|
||||
"KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT",
|
||||
"ENTITY_MEMORY_CONVERSATION_TEMPLATE",
|
||||
"SUMMARY_PROMPT",
|
||||
]
|
||||
|
||||
@@ -1,26 +1,5 @@
|
||||
from typing import Any, Dict, List
|
||||
from langchain_core.legacy.memory import ReadOnlySharedMemory
|
||||
|
||||
from langchain_core.memory import BaseMemory
|
||||
|
||||
|
||||
class ReadOnlySharedMemory(BaseMemory):
|
||||
"""A memory wrapper that is read-only and cannot be changed."""
|
||||
|
||||
memory: BaseMemory
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Return memory variables."""
|
||||
return self.memory.memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load memory variables from memory."""
|
||||
return self.memory.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
__all__ = [
|
||||
"ReadOnlySharedMemory",
|
||||
]
|
||||
|
||||
@@ -1,26 +1,3 @@
|
||||
from typing import Any, Dict, List
|
||||
from langchain_core.legacy.memory import SimpleMemory
|
||||
|
||||
from langchain_core.memory import BaseMemory
|
||||
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
"""Simple memory for storing context or other information that shouldn't
|
||||
ever change between prompts.
|
||||
"""
|
||||
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed, my memory is set in stone."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
__all__ = ["SimpleMemory"]
|
||||
|
||||
@@ -1,98 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
from langchain_core.legacy.memory.summary import (
|
||||
ConversationSummaryMemory,
|
||||
SummarizerMixin,
|
||||
)
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
|
||||
|
||||
class SummarizerMixin(BaseModel):
|
||||
"""Mixin for summarizer."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
|
||||
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Conversation summarizer to chat memory."""
|
||||
|
||||
buffer: str = ""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chat_memory: BaseChatMessageHistory,
|
||||
*,
|
||||
summarize_step: int = 2,
|
||||
**kwargs: Any,
|
||||
) -> ConversationSummaryMemory:
|
||||
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
|
||||
for i in range(0, len(obj.chat_memory.messages), summarize_step):
|
||||
obj.buffer = obj.predict_new_summary(
|
||||
obj.chat_memory.messages[i : i + summarize_step], obj.buffer
|
||||
)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||
else:
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.buffer = self.predict_new_summary(
|
||||
self.chat_memory.messages[-2:], self.buffer
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.buffer = ""
|
||||
__all__ = [
|
||||
"ConversationSummaryMemory",
|
||||
"SummarizerMixin",
|
||||
]
|
||||
|
||||
@@ -1,78 +1,3 @@
|
||||
from typing import Any, Dict, List
|
||||
from langchain_core.legacy.memory import ConversationSummaryBufferMemory
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
|
||||
|
||||
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Buffer with summarizer for storing conversation memory."""
|
||||
|
||||
max_token_limit: int = 2000
|
||||
moving_summary_buffer: str = ""
|
||||
memory_key: str = "history"
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer = self.buffer
|
||||
if self.moving_summary_buffer != "":
|
||||
first_messages: List[BaseMessage] = [
|
||||
self.summary_message_cls(content=self.moving_summary_buffer)
|
||||
]
|
||||
buffer = first_messages + buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
raise ValueError(
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.prune()
|
||||
|
||||
def prune(self) -> None:
|
||||
"""Prune buffer if it exceeds max token limit"""
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
pruned_memory, self.moving_summary_buffer
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.moving_summary_buffer = ""
|
||||
__all__ = ["ConversationSummaryBufferMemory"]
|
||||
|
||||
@@ -1,59 +1,5 @@
|
||||
from typing import Any, Dict, List
|
||||
from langchain_core.legacy.memory import ConversationTokenBufferMemory
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
class ConversationTokenBufferMemory(BaseChatMemory):
|
||||
"""Conversation chat memory with token limit."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "history"
|
||||
max_token_limit: int = 2000
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer. Pruned."""
|
||||
super().save_context(inputs, outputs)
|
||||
# Prune buffer if it exceeds max token limit
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
__all__ = [
|
||||
"ConversationTokenBufferMemory",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,3 @@
|
||||
from typing import Any, Dict, List
|
||||
from langchain_core.legacy.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
|
||||
"""
|
||||
Get the prompt input key.
|
||||
|
||||
Args:
|
||||
inputs: Dict[str, Any]
|
||||
memory_variables: List[str]
|
||||
|
||||
Returns:
|
||||
A prompt input key.
|
||||
"""
|
||||
# "stop" is a special key that can be passed as input but is not used to
|
||||
# format the prompt.
|
||||
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
|
||||
if len(prompt_input_keys) != 1:
|
||||
raise ValueError(f"One input key expected got {prompt_input_keys}")
|
||||
return prompt_input_keys[0]
|
||||
__all__ = ["get_prompt_input_key"]
|
||||
|
||||
@@ -1,101 +1,5 @@
|
||||
"""Class for a VectorStore-backed memory object."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from langchain_core.legacy.memory import VectorStoreRetrieverMemory
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
from langchain.memory.chat_memory import BaseMemory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""VectorStoreRetriever-backed memory."""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStoreRetriever object to connect to."""
|
||||
|
||||
memory_key: str = "history" #: :meta private:
|
||||
"""Key name to locate the memories in the result of load_memory_variables."""
|
||||
|
||||
input_key: Optional[str] = None
|
||||
"""Key name to index the inputs to load_memory_variables."""
|
||||
|
||||
return_docs: bool = False
|
||||
"""Whether or not to return the result of querying the database directly."""
|
||||
|
||||
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
|
||||
"""Input keys to exclude in addition to memory key when constructing the document"""
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The list of keys emitted from the load_memory_variables method."""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _documents_to_memory_variables(
|
||||
self, docs: List[Document]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
result: Union[List[Document], str]
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.invoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
async def aload_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = await self.retriever.ainvoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
def _form_documents(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> List[Document]:
|
||||
"""Format context from this conversation to buffer."""
|
||||
# Each document should only include the current turn, not the chat history
|
||||
exclude = set(self.exclude_input_keys)
|
||||
exclude.add(self.memory_key)
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
|
||||
texts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in list(filtered_inputs.items()) + list(outputs.items())
|
||||
]
|
||||
page_content = "\n".join(texts)
|
||||
return [Document(page_content=page_content)]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
await self.retriever.aadd_documents(documents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
__all__ = ["VectorStoreRetrieverMemory"]
|
||||
|
||||
@@ -1,125 +1,3 @@
|
||||
from __future__ import annotations
|
||||
from langchain_community.memory import ZepMemory
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_community.chat_message_histories import ZepChatMessageHistory
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
|
||||
class ZepMemory(ConversationBufferMemory):
|
||||
"""Persist your chain history to the Zep MemoryStore.
|
||||
|
||||
The number of messages returned by Zep and when the Zep server summarizes chat
|
||||
histories is configurable. See the Zep documentation for more details.
|
||||
|
||||
Documentation: https://docs.getzep.com
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
memory = ZepMemory(
|
||||
session_id=session_id, # Identifies your user or a user's session
|
||||
url=ZEP_API_URL, # Your Zep server's URL
|
||||
api_key=<your_api_key>, # Optional
|
||||
memory_key="history", # Ensure this matches the key used in
|
||||
# chain's prompt template
|
||||
return_messages=True, # Does your prompt template expect a string
|
||||
# or a list of Messages?
|
||||
)
|
||||
chain = LLMChain(memory=memory,...) # Configure your chain to use the ZepMemory
|
||||
instance
|
||||
|
||||
|
||||
Note:
|
||||
To persist metadata alongside your chat history, your will need to create a
|
||||
custom Chain class that overrides the `prep_outputs` method to include the metadata
|
||||
in the call to `self.memory.save_context`.
|
||||
|
||||
|
||||
Zep - Fast, scalable building blocks for LLM Apps
|
||||
=========
|
||||
Zep is an open source platform for productionizing LLM apps. Go from a prototype
|
||||
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
|
||||
rewriting code.
|
||||
|
||||
For server installation instructions and more, see:
|
||||
https://docs.getzep.com/deployment/quickstart/
|
||||
|
||||
For more information on the zep-python package, see:
|
||||
https://github.com/getzep/zep-python
|
||||
|
||||
"""
|
||||
|
||||
chat_memory: ZepChatMessageHistory
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
url: str = "http://localhost:8000",
|
||||
api_key: Optional[str] = None,
|
||||
output_key: Optional[str] = None,
|
||||
input_key: Optional[str] = None,
|
||||
return_messages: bool = False,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "AI",
|
||||
memory_key: str = "history",
|
||||
):
|
||||
"""Initialize ZepMemory.
|
||||
|
||||
Args:
|
||||
session_id (str): Identifies your user or a user's session
|
||||
url (str, optional): Your Zep server's URL. Defaults to
|
||||
"http://localhost:8000".
|
||||
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
|
||||
output_key (Optional[str], optional): The key to use for the output message.
|
||||
Defaults to None.
|
||||
input_key (Optional[str], optional): The key to use for the input message.
|
||||
Defaults to None.
|
||||
return_messages (bool, optional): Does your prompt template expect a string
|
||||
or a list of Messages? Defaults to False
|
||||
i.e. return a string.
|
||||
human_prefix (str, optional): The prefix to use for human messages.
|
||||
Defaults to "Human".
|
||||
ai_prefix (str, optional): The prefix to use for AI messages.
|
||||
Defaults to "AI".
|
||||
memory_key (str, optional): The key to use for the memory.
|
||||
Defaults to "history".
|
||||
Ensure that this matches the key used in
|
||||
chain's prompt template.
|
||||
"""
|
||||
chat_message_history = ZepChatMessageHistory(
|
||||
session_id=session_id,
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
)
|
||||
super().__init__(
|
||||
chat_memory=chat_message_history,
|
||||
output_key=output_key,
|
||||
input_key=input_key,
|
||||
return_messages=return_messages,
|
||||
human_prefix=human_prefix,
|
||||
ai_prefix=ai_prefix,
|
||||
memory_key=memory_key,
|
||||
)
|
||||
|
||||
def save_context(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
outputs: Dict[str, str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): The inputs to the chain.
|
||||
outputs (Dict[str, str]): The outputs from the chain.
|
||||
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
|
||||
the context. Defaults to None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_user_message(input_str, metadata=metadata)
|
||||
self.chat_memory.add_ai_message(output_str, metadata=metadata)
|
||||
__all__ = ["ZepMemory"]
|
||||
|
||||
Reference in New Issue
Block a user