Compare commits

...

13 Commits

Author  SHA1  Message  Date
Eugene Yurtsev  9edffaaed2  x  2024-04-23 17:20:45 -04:00
Eugene Yurtsev  d09f8eebff  x  2024-04-23 17:15:37 -04:00
Eugene Yurtsev  1b19f839f9  x  2024-04-23 17:12:44 -04:00
Eugene Yurtsev  e8d99c9620  x  2024-04-23 16:57:04 -04:00
Eugene Yurtsev  c52a84c5a3  x  2024-04-23 16:55:05 -04:00
Eugene Yurtsev  1ac61323d3  x  2024-04-23 16:41:30 -04:00
Eugene Yurtsev  f82b2f4a6f  x  2024-04-23 16:41:06 -04:00
Eugene Yurtsev  3755822a2d  x  2024-04-23 16:35:29 -04:00
Eugene Yurtsev  017ae731d4  x  2024-04-23 16:34:50 -04:00
Eugene Yurtsev  9ac0b0026b  x  2024-04-23 16:34:13 -04:00
Eugene Yurtsev  59fbe77510  x  2024-04-23 16:27:53 -04:00
Eugene Yurtsev  8aea083bf3  Merge branch 'master' into eugene/move_memories_2  2024-04-23 16:11:25 -04:00
Eugene Yurtsev  aaf376a681  x  2024-04-23 15:54:48 -04:00
47 changed files with 3195 additions and 2982 deletions

View File

@@ -0,0 +1,17 @@
from langchain_community.memory.entity import (
RedisEntityStore,
SQLiteEntityStore,
UpstashRedisEntityStore,
)
from langchain_community.memory.kg import ConversationKGMemory
from langchain_community.memory.motorhead_memory import MotorheadMemory
from langchain_community.memory.zep_memory import ZepMemory
__all__ = [
"ConversationKGMemory",
"RedisEntityStore",
"UpstashRedisEntityStore",
"SQLiteEntityStore",
"MotorheadMemory",
"ZepMemory",
]

View File

@@ -0,0 +1,268 @@
import logging
from itertools import islice
from typing import Any, Iterable, Optional
from langchain_core.legacy.memory.entity import BaseEntityStore
from langchain_community.utilities.redis import get_client
logger = logging.getLogger(__name__)
class UpstashRedisEntityStore(BaseEntityStore):
"""Upstash Redis backed Entity store.
Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""
def __init__(
self,
session_id: str = "default",
url: str = "",
token: str = "",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
from upstash_redis import Redis
except ImportError:
raise ImportError(
"Could not import upstash_redis python package. "
"Please install it with `pip install upstash_redis`."
)
super().__init__(*args, **kwargs)
try:
self.redis_client = Redis(url=url, token=token)
except Exception:
logger.error("Upstash Redis instance could not be initiated.")
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl
@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
def exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
def clear(self) -> None:
def scan_and_delete(cursor: int) -> int:
cursor, keys_to_delete = self.redis_client.scan(
cursor, f"{self.full_key_prefix}:*"
)
self.redis_client.delete(*keys_to_delete)
return cursor
cursor = scan_and_delete(0)
while cursor != 0:
scan_and_delete(cursor)
class RedisEntityStore(BaseEntityStore):
"""Redis-backed Entity store.
Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""
redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
super().__init__(*args, **kwargs)
try:
self.redis_client = get_client(redis_url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl
@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
def exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
def clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch
for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)
class SQLiteEntityStore(BaseEntityStore):
"""SQLite-backed Entity store"""
session_id: str = "default"
table_name: str = "memory_store"
conn: Any = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def __init__(
self,
session_id: str = "default",
db_file: str = "entities.db",
table_name: str = "memory_store",
*args: Any,
**kwargs: Any,
):
try:
import sqlite3
except ImportError:
raise ImportError(
"Could not import sqlite3 python package. "
"Please install it with `pip install sqlite3`."
)
super().__init__(*args, **kwargs)
self.conn = sqlite3.connect(db_file)
self.session_id = session_id
self.table_name = table_name
self._create_table_if_not_exists()
@property
def full_table_name(self) -> str:
return f"{self.table_name}_{self.session_id}"
def _create_table_if_not_exists(self) -> None:
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
key TEXT PRIMARY KEY,
value TEXT
)
"""
with self.conn:
self.conn.execute(create_table_query)
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
query = f"""
SELECT value
FROM {self.full_table_name}
WHERE key = ?
"""
cursor = self.conn.execute(query, (key,))
result = cursor.fetchone()
if result is not None:
value = result[0]
return value
return default
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
query = f"""
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
VALUES (?, ?)
"""
with self.conn:
self.conn.execute(query, (key, value))
def delete(self, key: str) -> None:
query = f"""
DELETE FROM {self.full_table_name}
WHERE key = ?
"""
with self.conn:
self.conn.execute(query, (key,))
def exists(self, key: str) -> bool:
query = f"""
SELECT 1
FROM {self.full_table_name}
WHERE key = ?
LIMIT 1
"""
cursor = self.conn.execute(query, (key,))
result = cursor.fetchone()
return result is not None
def clear(self) -> None:
query = f"""
DELETE FROM {self.full_table_name}
"""
with self.conn:
self.conn.execute(query)
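
A minimal usage sketch (not part of the diff) for the SQLite-backed store above. It assumes the import path introduced in this branch; the ':memory:' database keeps the example side-effect free:

from langchain_community.memory.entity import SQLiteEntityStore

# In-memory SQLite database scoped to one session; nothing is written to disk.
store = SQLiteEntityStore(session_id="demo", db_file=":memory:")

store.set("Alice", "Alice is a software engineer who likes Redis.")
assert store.exists("Alice")
print(store.get("Alice"))           # the stored fact
print(store.get("Bob", "unknown"))  # falls back to the provided default

store.delete("Alice")
store.clear()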

View File

@@ -0,0 +1,133 @@
from typing import Any, Dict, List, Type, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.legacy.chains.llm import LLMChain
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.legacy.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain_core.legacy.memory.utils import get_prompt_input_key
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
class ConversationKGMemory(BaseChatMemory):
"""Knowledge graph conversation memory.
Integrates with external knowledge graph to store and retrieve
information about knowledge triples in the conversation.
"""
k: int = 2
"""Number of previous utterances to include in the context."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)
summary_strings = []
for entity in entities:
knowledge = self.kg.get_entity_knowledge(entity)
if knowledge:
summary = f"On {entity}: {'. '.join(knowledge)}."
summary_strings.append(summary)
context: Union[str, List]
if not summary_strings:
context = [] if self.return_messages else ""
elif self.return_messages:
context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else:
context = "\n".join(summary_strings)
return {self.memory_key: context}
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
"""Get the output key for the prompt."""
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
return list(outputs.keys())[0]
return self.output_key
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
return get_entities(output)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
verbose=True,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()
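
A hedged usage sketch for ConversationKGMemory (not part of the diff). It assumes langchain_openai is installed and an OpenAI API key is configured; any BaseLanguageModel would work in place of ChatOpenAI:

from langchain_community.memory.kg import ConversationKGMemory
from langchain_openai import ChatOpenAI  # assumed model provider, for illustration only

memory = ConversationKGMemory(llm=ChatOpenAI(temperature=0), k=2)

# Saving a turn extracts knowledge triples from the input and adds them to the graph.
memory.save_context(
    {"input": "Sam works at Acme and lives in Boise."},
    {"output": "Got it, thanks for telling me about Sam."},
)

# On the next turn, knowledge about entities mentioned in the new input is surfaced.
print(memory.load_memory_variables({"input": "Where does Sam work?"}))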

View File

@@ -0,0 +1,92 @@
from typing import Any, Dict, List, Optional
import requests
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.messages import get_buffer_string
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL
timeout: int = 3000
memory_key: str = "history"
session_id: str
context: Optional[str] = None
# Managed Params
api_key: Optional[str] = None
client_id: Optional[str] = None
def __get_headers(self) -> Dict[str, str]:
is_managed = self.url == MANAGED_URL
headers = {
"Content-Type": "application/json",
}
if is_managed and not (self.api_key and self.client_id):
raise ValueError(
"""
You must provide both an API key and a client ID to use the managed
version of Motorhead. Visit https://getmetal.io for more information.
"""
)
if is_managed and self.api_key and self.client_id:
headers["x-metal-api-key"] = self.api_key
headers["x-metal-client-id"] = self.client_id
return headers
async def init(self) -> None:
res = requests.get(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
headers=self.__get_headers(),
)
res_data = res.json()
res_data = res_data.get("data", res_data) # Handle Managed Version
messages = res_data.get("messages", [])
context = res_data.get("context", "NONE")
for message in reversed(messages):
if message["role"] == "AI":
self.chat_memory.add_ai_message(message["content"])
else:
self.chat_memory.add_user_message(message["content"])
if context and context != "NONE":
self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
if self.return_messages:
return {self.memory_key: self.chat_memory.messages}
else:
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
@property
def memory_variables(self) -> List[str]:
return [self.memory_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
input_str, output_str = self._get_input_output(inputs, outputs)
requests.post(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
json={
"messages": [
{"role": "Human", "content": f"{input_str}"},
{"role": "AI", "content": f"{output_str}"},
]
},
headers=self.__get_headers(),
)
super().save_context(inputs, outputs)
def delete_session(self) -> None:
"""Delete a session"""
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
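
A sketch of how MotorheadMemory might be wired up (hypothetical, not part of the diff; it assumes a self-hosted Motorhead server reachable at the URL shown, otherwise the HTTP calls will fail):

import asyncio

from langchain_community.memory.motorhead_memory import MotorheadMemory

memory = MotorheadMemory(
    session_id="user-42",
    url="http://localhost:8080",  # assumed self-hosted server; the managed URL is the default
)

# init() is declared async but issues a blocking requests.get; it pre-loads any
# existing messages for the session into chat_memory.
asyncio.run(memory.init())

memory.save_context({"input": "Hi, I'm Sam."}, {"output": "Hello Sam!"})
print(memory.load_memory_variables({}))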

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from langchain_core.legacy.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import ZepChatMessageHistory
class ZepMemory(ConversationBufferMemory):
"""Persist your chain history to the Zep MemoryStore.
The number of messages returned by Zep and when the Zep server summarizes chat
histories is configurable. See the Zep documentation for more details.
Documentation: https://docs.getzep.com
Example:
.. code-block:: python
memory = ZepMemory(
session_id=session_id, # Identifies your user or a user's session
url=ZEP_API_URL, # Your Zep server's URL
api_key=<your_api_key>, # Optional
memory_key="history", # Ensure this matches the key used in
# chain's prompt template
return_messages=True, # Does your prompt template expect a string
# or a list of Messages?
)
chain = LLMChain(memory=memory, ...)  # Configure your chain to use the ZepMemory instance
Note:
To persist metadata alongside your chat history, you will need to create a
custom Chain class that overrides the `prep_outputs` method to include the metadata
in the call to `self.memory.save_context`.
Zep - Fast, scalable building blocks for LLM Apps
=========
Zep is an open source platform for productionizing LLM apps. Go from a prototype
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
rewriting code.
For server installation instructions and more, see:
https://docs.getzep.com/deployment/quickstart/
For more information on the zep-python package, see:
https://github.com/getzep/zep-python
"""
chat_memory: ZepChatMessageHistory
def __init__(
self,
session_id: str,
url: str = "http://localhost:8000",
api_key: Optional[str] = None,
output_key: Optional[str] = None,
input_key: Optional[str] = None,
return_messages: bool = False,
human_prefix: str = "Human",
ai_prefix: str = "AI",
memory_key: str = "history",
):
"""Initialize ZepMemory.
Args:
session_id (str): Identifies your user or a user's session
url (str, optional): Your Zep server's URL. Defaults to
"http://localhost:8000".
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
output_key (Optional[str], optional): The key to use for the output message.
Defaults to None.
input_key (Optional[str], optional): The key to use for the input message.
Defaults to None.
return_messages (bool, optional): Does your prompt template expect a string
or a list of Messages? Defaults to False
i.e. return a string.
human_prefix (str, optional): The prefix to use for human messages.
Defaults to "Human".
ai_prefix (str, optional): The prefix to use for AI messages.
Defaults to "AI".
memory_key (str, optional): The key to use for the memory.
Defaults to "history".
Ensure that this matches the key used in
chain's prompt template.
"""
chat_message_history = ZepChatMessageHistory(
session_id=session_id,
url=url,
api_key=api_key,
)
super().__init__(
chat_memory=chat_message_history,
output_key=output_key,
input_key=input_key,
return_messages=return_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
memory_key=memory_key,
)
def save_context(
self,
inputs: Dict[str, Any],
outputs: Dict[str, str],
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save context from this conversation to buffer.
Args:
inputs (Dict[str, Any]): The inputs to the chain.
outputs (Dict[str, str]): The outputs from the chain.
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
the context. Defaults to None
Returns:
None
"""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)
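
The Note in the ZepMemory docstring above says metadata persistence requires a custom Chain that overrides `prep_outputs`. A possible sketch (MetadataLLMChain is a hypothetical name, and forwarding the chain-level `metadata` field is just one way to source the metadata):

from typing import Dict

from langchain_core.legacy.chains.llm import LLMChain


class MetadataLLMChain(LLMChain):
    """Hypothetical LLMChain that forwards chain metadata to ZepMemory.save_context."""

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        self._validate_outputs(outputs)
        if self.memory is not None:
            # ZepMemory.save_context accepts an extra `metadata` argument.
            self.memory.save_context(inputs, outputs, metadata=self.metadata)
        if return_only_outputs:
            return outputs
        return {**inputs, **outputs}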

View File

@@ -0,0 +1,7 @@
from langchain_core.legacy.chains.base import Chain
from langchain_core.legacy.chains.llm import LLMChain
__all__ = [
"Chain",
"LLMChain",
]

View File

@@ -0,0 +1,734 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast
import yaml
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
BaseCallbackManager,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.legacy.memory import BaseMemory
from langchain_core.load.dump import dumpd
from langchain_core.outputs import RunInfo
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
ensure_config,
run_in_executor,
)
from langchain_core.runnables.utils import create_model
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
RUN_KEY = "__run"
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
The main methods exposed by chains are:
- `__call__`: Chains are callable. The `__call__` method is the primary way to
execute a Chain. This takes inputs as a dictionary and returns a
dictionary output.
- `run`: A convenience method that takes inputs as args/kwargs and returns the
output as a string or object. This method can only be used for a subset of
chains and cannot return as rich of an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Use `callbacks` instead."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
)
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = await self.aprep_inputs(input)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
if values.get("callbacks") is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks. "
"callback_manager is deprecated, callbacks is the preferred "
"parameter to pass in."
)
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""Set the chain verbosity.
Defaults to the global setting if not specified by the user.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.__call__`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.acall`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
return await run_in_executor(
None, self._call, inputs, run_manager.get_sync() if run_manager else None
)
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return self.invoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return await self.ainvoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
inputs: Dictionary of chain inputs, including any inputs added by chain
memory.
outputs: Dictionary of initial chain outputs.
return_only_outputs: Whether to only return the chain outputs. If False,
inputs are also added to the final outputs.
Returns:
A dict of the final chain outputs.
"""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = await self.memory.aload_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs
@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
return self.output_keys[0]
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
The main difference between this method and `Chain.__call__` is that this
method expects inputs to be passed directly in as positional arguments or
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
with all the inputs
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
chain.run("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
chain.run(question=question, context=context)
# -> "The temperature in Boise is..."
"""
# Run at start to make sure this is possible/defined
_output_key = self._run_output_key
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
else:
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def arun(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
The main difference between this method and `Chain.__call__` is that this
method expects inputs to be passed directly in as positional arguments or
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
with all the inputs
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
await chain.arun("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
await chain.arun(question=question, context=context)
# -> "The temperature in Boise is..."
"""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
if kwargs and not args:
return (
await self.acall(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
method.
Returns:
A dictionary representation of the chain.
Example:
.. code-block:: python
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
_dict = super().dict(**kwargs)
try:
_dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
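
To make the Chain interface above concrete, here is a minimal sketch of a custom subclass (UppercaseChain is a hypothetical example, not part of the diff; the import path follows this branch's layout):

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.legacy.chains.base import Chain


class UppercaseChain(Chain):
    """Hypothetical minimal Chain: uppercases its single input."""

    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["uppercase"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {"uppercase": inputs["text"].upper()}


chain = UppercaseChain()
print(chain.invoke({"text": "hello"}))  # -> {'text': 'hello', 'uppercase': 'HELLO'}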

View File

@@ -0,0 +1,422 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.language_models import (
BaseLanguageModel,
LanguageModelInput,
)
from langchain_core.legacy.chains.base import Chain
from langchain_core.load.dump import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableWithFallbacks,
)
from langchain_core.runnables.configurable import DynamicRunnable
from langchain_core.utils.input import get_colored_text
@deprecated(
since="0.1.17",
alternative="RunnableSequence, e.g., `prompt | llm`",
removal="0.3.0",
)
class LLMChain(Chain):
"""Chain to run queries against LLMs.
This class is deprecated. See below for an example implementation using
LangChain runnables:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = OpenAI()
chain = prompt | llm
chain.invoke("your adjective here")
Example:
.. code-block:: python
from langchain.chains import LLMChain
from langchain_community.llms import OpenAI
from langchain_core.prompts import PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: Union[
Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
]
"""Language model to call."""
output_key: str = "text" #: :meta private:
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
"""Output parser to use.
Defaults to one that takes the most likely string but does not change it
otherwise."""
return_final_only: bool = True
"""Whether to return only the final parsed result. Defaults to True.
If false, will return a bunch of extra information about the generation."""
llm_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
callbacks = run_manager.get_child() if run_manager else None
if isinstance(self.llm, BaseLanguageModel):
return self.llm.generate_prompt(
prompts,
stop,
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(List, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
else:
generations.append([Generation(text=res)])
return LLMResult(generations=generations)
async def agenerate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
callbacks = run_manager.get_child() if run_manager else None
if isinstance(self.llm, BaseLanguageModel):
return await self.llm.agenerate_prompt(
prompts,
stop,
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(List, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
else:
generations.append([Generation(text=res)])
return LLMResult(generations=generations)
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
return [], stop
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
async def aprep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
return [], stop
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = await self.agenerate(input_list, run_manager=run_manager)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
await run_manager.on_chain_end({"outputs": outputs})
return outputs
@property
def _run_output_key(self) -> str:
return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_generation(result)
def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return generation
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_generation(result)
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
def _get_num_tokens(self, text: str) -> int:
return _get_language_model(self.llm).get_num_tokens(text)
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
if isinstance(llm_like, BaseLanguageModel):
return llm_like
elif isinstance(llm_like, RunnableBinding):
return _get_language_model(llm_like.bound)
elif isinstance(llm_like, RunnableWithFallbacks):
return _get_language_model(llm_like.runnable)
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
return _get_language_model(llm_like.default)
else:
raise ValueError(
f"Unable to extract BaseLanguageModel from llm_like object of type "
f"{type(llm_like)}"
)
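
A small usage sketch for the from_string helper above (not part of the diff). FakeListLLM from langchain_community returns canned responses, so the example runs without an API key:

from langchain_community.llms import FakeListLLM

from langchain_core.legacy.chains.llm import LLMChain

# Canned response instead of a real model call.
llm = FakeListLLM(responses=["Why did the chicken cross the road? To get to the other side."])

chain = LLMChain.from_string(llm=llm, template="Tell me a {adjective} joke")
print(chain.predict(adjective="funny"))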

View File

@@ -0,0 +1,35 @@
from langchain_core.legacy.memory.base import BaseMemory
from langchain_core.legacy.memory.buffer import (
ConversationBufferMemory,
ConversationStringBufferMemory,
)
from langchain_core.legacy.memory.buffer_window import ConversationBufferWindowMemory
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.legacy.memory.combined import CombinedMemory
from langchain_core.legacy.memory.entity import (
ConversationEntityMemory,
InMemoryEntityStore,
)
from langchain_core.legacy.memory.readonly import ReadOnlySharedMemory
from langchain_core.legacy.memory.simple import SimpleMemory
from langchain_core.legacy.memory.summary import ConversationSummaryMemory
from langchain_core.legacy.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain_core.legacy.memory.token_buffer import ConversationTokenBufferMemory
from langchain_core.legacy.memory.vectorstore import VectorStoreRetrieverMemory
__all__ = [
"BaseMemory",
"BaseChatMemory",
"CombinedMemory",
"ConversationBufferMemory",
"ConversationBufferWindowMemory",
"ConversationEntityMemory",
"ConversationStringBufferMemory",
"ConversationSummaryBufferMemory",
"ConversationSummaryMemory",
"ConversationTokenBufferMemory",
"InMemoryEntityStore",
"ReadOnlySharedMemory",
"SimpleMemory",
"VectorStoreRetrieverMemory",
]

View File

@@ -0,0 +1,83 @@
"""**Memory** maintains Chain state, incorporating context from past runs.
**Class hierarchy for Memory:**
.. code-block::
BaseMemory --> <name>Memory --> <name>Memory # Examples: BaseChatMemory -> MotorheadMemory
""" # noqa: E501
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import run_in_executor
class BaseMemory(Serializable, ABC):
"""Abstract base class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
return await run_in_executor(None, self.load_memory_variables, inputs)
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save the context of this chain run to memory."""
await run_in_executor(None, self.save_context, inputs, outputs)
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""
async def aclear(self) -> None:
"""Clear memory contents."""
await run_in_executor(None, self.clear)
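The async methods above default to running their sync counterparts in an executor via run_in_executor. A small sketch of that behaviour, assuming the legacy package path introduced in this PR and using the SimpleMemory class defined later in this diff:

import asyncio

from langchain_core.legacy.memory import SimpleMemory  # assumed path per this PR

memory = SimpleMemory(memories={"facts": "The sky is blue."})

async def main() -> None:
    # aload_memory_variables falls back to the sync load_memory_variables.
    print(await memory.aload_memory_variables({}))  # {'facts': 'The sky is blue.'}

asyncio.run(main())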

View File

@@ -0,0 +1,136 @@
from typing import Any, Dict, List, Optional
from langchain_core.legacy.memory.base import BaseMemory
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.legacy.memory.utils import get_prompt_input_key
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
class ConversationBufferMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
@property
def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
async def abuffer(self) -> Any:
"""String buffer of memory."""
return (
await self.abuffer_as_messages()
if self.return_messages
else await self.abuffer_as_str()
)
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
return get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
return self._buffer_as_str(self.chat_memory.messages)
async def abuffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
messages = await self.chat_memory.aget_messages()
return self._buffer_as_str(messages)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.chat_memory.messages
async def abuffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return await self.chat_memory.aget_messages()
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
buffer = await self.abuffer()
return {self.memory_key: buffer}
class ConversationStringBufferMemory(BaseMemory):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: str = ""
output_key: Optional[str] = None
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that return messages is not True."""
if values.get("return_messages", False):
raise ValueError(
"return_messages must be False for ConversationStringBufferMemory"
)
return values
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return self.load_memory_variables(inputs)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer += "\n" + "\n".join([human, ai])
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
return self.save_context(inputs, outputs)
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""
async def aclear(self) -> None:
self.clear()
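A minimal round-trip sketch for ConversationBufferMemory; the import path is an assumption based on this PR's layout:

from langchain_core.legacy.memory import ConversationBufferMemory  # assumed path

memory = ConversationBufferMemory()
memory.save_context({"input": "Hi there"}, {"output": "Hello! How can I help?"})
print(memory.load_memory_variables({}))
# {'history': 'Human: Hi there\nAI: Hello! How can I help?'}
memory.clear()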

View File

@@ -0,0 +1,46 @@
from typing import Any, Dict, List, Union
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.messages import BaseMessage, get_buffer_string
class ConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory inside a limited size window."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
"""Number of messages to store in buffer."""
@property
def buffer(self) -> Union[str, List[BaseMessage]]:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is False."""
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
return get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is True."""
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
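A sketch of the windowing behaviour: with k=1 only the last human/AI pair is surfaced. The import path is assumed per this PR:

from langchain_core.legacy.memory import ConversationBufferWindowMemory  # assumed path

memory = ConversationBufferWindowMemory(k=1)
memory.save_context({"input": "First question"}, {"output": "First answer"})
memory.save_context({"input": "Second question"}, {"output": "Second answer"})
print(memory.load_memory_variables({}))
# {'history': 'Human: Second question\nAI: Second answer'}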

View File

@@ -0,0 +1,74 @@
import warnings
from abc import ABC
from typing import Any, Dict, Optional, Tuple
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.legacy.memory import BaseMemory
from langchain_core.legacy.memory.utils import get_prompt_input_key
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import Field
class BaseChatMemory(BaseMemory, ABC):
"""Abstract base class for chat memory."""
chat_memory: BaseChatMessageHistory = Field(
default_factory=InMemoryChatMessageHistory
)
output_key: Optional[str] = None
input_key: Optional[str] = None
return_messages: bool = False
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) == 1:
output_key = list(outputs.keys())[0]
elif "output" in outputs:
output_key = "output"
warnings.warn(
f"'{self.__class__.__name__}' got multiple output keys:"
f" {outputs.keys()}. The default 'output' key is being used."
f" If this is not desired, please manually set 'output_key'."
)
else:
raise ValueError(
f"Got multiple output keys: {outputs.keys()}, cannot "
f"determine which to store in memory. Please set the "
f"'output_key' explicitly."
)
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
)
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
await self.chat_memory.aadd_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
)
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
async def aclear(self) -> None:
"""Clear memory contents."""
await self.chat_memory.aclear()
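A sketch of how _get_input_output resolves keys when a chain returns several outputs, using ConversationBufferMemory as a concrete subclass (import path assumed per this PR):

from langchain_core.legacy.memory import ConversationBufferMemory  # assumed path

# With multiple output keys, the default 'output' key is used and a warning is emitted.
memory = ConversationBufferMemory()
memory.save_context(
    {"input": "What is 2 + 2?"},
    {"output": "4", "reasoning": "basic arithmetic"},
)

# Setting output_key explicitly picks the value to store and avoids the warning.
explicit = ConversationBufferMemory(output_key="reasoning")
explicit.save_context(
    {"input": "What is 2 + 2?"},
    {"output": "4", "reasoning": "basic arithmetic"},
)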

View File

@@ -0,0 +1,81 @@
import warnings
from typing import Any, Dict, List, Set
from langchain_core.legacy.memory import BaseMemory
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.pydantic_v1 import validator
class CombinedMemory(BaseMemory):
"""Combining multiple memories' data together."""
memories: List[BaseMemory]
"""For tracking all the memories that should be accessed."""
@validator("memories")
def check_repeated_memory_variable(
cls, value: List[BaseMemory]
) -> List[BaseMemory]:
all_variables: Set[str] = set()
for val in value:
overlap = all_variables.intersection(val.memory_variables)
if overlap:
raise ValueError(
f"The same variables {overlap} are found in multiple"
"memory object, which is not allowed by CombinedMemory."
)
all_variables |= set(val.memory_variables)
return value
@validator("memories")
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value:
if isinstance(val, BaseChatMemory):
if val.input_key is None:
warnings.warn(
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}"
)
return value
@property
def memory_variables(self) -> List[str]:
"""All the memory variables that this instance provides."""
"""Collected from the all the linked memories."""
memory_variables = []
for memory in self.memories:
memory_variables.extend(memory.memory_variables)
return memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load all vars from sub-memories."""
memory_data: Dict[str, Any] = {}
# Collect vars from all sub-memories
for memory in self.memories:
data = memory.load_memory_variables(inputs)
for key, value in data.items():
if key in memory_data:
raise ValueError(
f"The variable {key} is repeated in the CombinedMemory."
)
memory_data[key] = value
return memory_data
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this session for every memory."""
# Save context for all sub-memories
for memory in self.memories:
memory.save_context(inputs, outputs)
def clear(self) -> None:
"""Clear context from this session for every memory."""
for memory in self.memories:
memory.clear()

View File

@@ -0,0 +1,221 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.legacy.chains import LLMChain
from langchain_core.legacy.memory import BaseChatMemory
from langchain_core.legacy.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
ENTITY_SUMMARIZATION_PROMPT,
)
from langchain_core.legacy.memory.utils import get_prompt_input_key
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
logger = logging.getLogger(__name__)
class BaseEntityStore(BaseModel, ABC):
"""Abstract base class for Entity store."""
@abstractmethod
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
"""Get entity value from store."""
pass
@abstractmethod
def set(self, key: str, value: Optional[str]) -> None:
"""Set entity value in store."""
pass
@abstractmethod
def delete(self, key: str) -> None:
"""Delete entity value from store."""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""Check if entity exists in store."""
pass
@abstractmethod
def clear(self) -> None:
"""Delete all entities from store."""
pass
class InMemoryEntityStore(BaseEntityStore):
"""In-memory Entity store."""
store: Dict[str, Optional[str]] = {}
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
return self.store.get(key, default)
def set(self, key: str, value: Optional[str]) -> None:
self.store[key] = value
def delete(self, key: str) -> None:
del self.store[key]
def exists(self, key: str) -> bool:
return key in self.store
def clear(self) -> None:
return self.store.clear()
class ConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer memory.
Extracts named entities from the recent chat history and generates summaries.
A swappable entity store allows entities to persist across conversations.
Defaults to an in-memory entity store, and can be swapped out for a Redis,
SQLite, or other entity store.
"""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
# Cache of recently detected entity names, if any
# It is updated when load_memory_variables is called:
entity_cache: List[str] = []
# Number of recent message pairs to consider when updating entities:
k: int = 3
chat_history_key: str = "history"
# Store to manage entity-related data:
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
@property
def buffer(self) -> List[BaseMessage]:
"""Access chat memory messages."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return ["entities", self.chat_history_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Returns chat history and all generated entities with summaries if available,
and updates or clears the recent entity cache.
New entity names may be found when calling this method, before their
summaries have been generated, so the entity cache values may be empty if no
entity descriptions have been generated yet.
"""
# Create an LLMChain for predicting entity names from the recent chat history:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
# Extract an arbitrary window of the last message pairs from
# the chat history, where the hyperparameter k is the
# number of message pairs:
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
# Generates a comma-separated list of named entities,
# e.g. "Jane, White House, UFO"
# or "NONE" if no named entities are extracted:
output = chain.predict(
history=buffer_string,
input=inputs[prompt_input_key],
)
# If no named entities are extracted, assigns an empty list.
if output.strip() == "NONE":
entities = []
else:
# Make a list of the extracted entities:
entities = [w.strip() for w in output.split(",")]
# Make a dictionary of entities with summary if exists:
entity_summaries = {}
for entity in entities:
entity_summaries[entity] = self.entity_store.get(entity, "")
# Replaces the entity name cache with the most recently discussed entities,
# or if no entities were extracted, clears the cache:
self.entity_cache = entities
# Should we return as message objects or as a string?
if self.return_messages:
# Get last `k` pair of chat messages:
buffer: Any = self.buffer[-self.k * 2 :]
else:
# Reuse the string we made earlier:
buffer = buffer_string
return {
self.chat_history_key: buffer,
"entities": entity_summaries,
}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""
Save context from this conversation history to the entity store.
Generates a summary for each entity in the entity cache by prompting
the model, and saves these summaries to the entity store.
"""
super().save_context(inputs, outputs)
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
# Extract an arbitrary window of the last message pairs from
# the chat history, where the hyperparameter k is the
# number of message pairs:
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
input_data = inputs[prompt_input_key]
# Create an LLMChain for predicting entity summarization from the context
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
# Generate new summaries for entities and save them in the entity store
for entity in self.entity_cache:
# Get existing summary if it exists
existing_summary = self.entity_store.get(entity, "")
output = chain.predict(
summary=existing_summary,
entity=entity,
history=buffer_string,
input=input_data,
)
# Save the updated summary to the entity store
self.entity_store.set(entity, output.strip())
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.entity_cache.clear()
self.entity_store.clear()
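A hedged end-to-end sketch driving ConversationEntityMemory with a deterministic fake model. FakeListLLM and the legacy module paths are assumptions, not part of this file; the first canned response feeds entity extraction, the second feeds entity summarization:

from langchain_core.language_models.fake import FakeListLLM  # assumed test model
from langchain_core.legacy.memory.entity import ConversationEntityMemory  # assumed path

llm = FakeListLLM(responses=["Langchain", "Langchain is a framework for LLM apps."])
memory = ConversationEntityMemory(llm=llm, k=2)

variables = memory.load_memory_variables({"input": "Tell me about Langchain."})
print(variables["entities"])  # {'Langchain': ''} -- no summary stored yet

memory.save_context({"input": "Tell me about Langchain."}, {"output": "It chains LLM calls."})
print(memory.entity_store.get("Langchain"))  # 'Langchain is a framework for LLM apps.'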

View File

@@ -0,0 +1,165 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI.
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
You are constantly learning and improving, and your capabilities are constantly evolving. You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. You have access to some personalized information provided by the human in the Context section below. Additionally, you are able to generate your own text based on the input you receive, allowing you to engage in discussions and provide explanations and descriptions on a wide range of topics.
Overall, you are a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether the human needs help with a specific question or just wants to have a conversation about a particular topic, you are here to assist.
Context:
{entities}
Current conversation:
{history}
Last line:
Human: {input}
You:"""
ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate(
input_variables=["entities", "history", "input"],
template=_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE,
)
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.
New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE
Current summary:
{summary}
New lines of conversation:
{new_lines}
New summary:"""
SUMMARY_PROMPT = PromptTemplate(
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
)
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
Output: Langchain, Person #2
END OF EXAMPLE
Conversation history (for reference only):
{history}
Last line of conversation (for extraction):
Human: {input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
Full conversation history (for context):
{history}
Entity to summarize:
{entity}
Existing summary of {entity}:
{summary}
Last line of conversation:
Human: {input}
Updated summary:"""
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
input_variables=["entity", "summary", "history", "input"],
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
)
KG_TRIPLE_DELIMITER = "<|>"
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
"You are a networked intelligence helping a human track knowledge triples"
" about all relevant people, things, concepts, etc. and integrating"
" them with your knowledge stored within your weights"
" as well as that stored in a knowledge graph."
" Extract all of the knowledge triples from the last line of conversation."
" A knowledge triple is a clause that contains a subject, a predicate,"
" and an object. The subject is the entity being described,"
" the predicate is the property of the subject that is being"
" described, and the object is the value of the property.\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: Did you hear aliens landed in Area 51?\n"
"AI: No, I didn't hear that. What do you know about Area 51?\n"
"Person #1: It's a secret military base in Nevada.\n"
"AI: What do you know about Nevada?\n"
"Last line of conversation:\n"
"Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: Hello.\n"
"AI: Hi! How are you?\n"
"Person #1: I'm good. How are you?\n"
"AI: I'm good too.\n"
"Last line of conversation:\n"
"Person #1: I'm going to the store.\n\n"
"Output: NONE\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: What do you know about Descartes?\n"
"AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n"
"Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n"
"AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n"
"Last line of conversation:\n"
"Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
"END OF EXAMPLE\n\n"
"Conversation history (for reference only):\n"
"{history}"
"\nLast line of conversation (for extraction):\n"
"Human: {input}\n\n"
"Output:"
)
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["history", "input"],
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
)
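A small sketch rendering one of the prompts above to show the variables it expects; the import path is an assumption based on this PR's layout:

from langchain_core.legacy.memory.prompt import ENTITY_EXTRACTION_PROMPT  # assumed path

text = ENTITY_EXTRACTION_PROMPT.format(
    history="Human: hi\nAI: hello",
    input="I met Alice at the Louvre yesterday.",
)
print(text.splitlines()[-1])  # 'Output:' -- the model is expected to complete this line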

View File

@@ -0,0 +1,26 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import BaseMemory
class ReadOnlySharedMemory(BaseMemory):
"""A memory wrapper that is read-only and cannot be changed."""
memory: BaseMemory
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load memory variables from memory."""
return self.memory.load_memory_variables(inputs)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
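A sketch of wrapping an existing memory in a read-only view (import paths assumed per this PR):

from langchain_core.legacy.memory import (  # assumed paths
    ConversationBufferMemory,
    ReadOnlySharedMemory,
)

inner = ConversationBufferMemory()
inner.save_context({"input": "Hi"}, {"output": "Hello"})

readonly = ReadOnlySharedMemory(memory=inner)
print(readonly.load_memory_variables({}))  # same history as `inner`
readonly.save_context({"input": "x"}, {"output": "y"})  # intentionally a no-op
readonly.clear()  # also a no-op; `inner` keeps its contents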

View File

@@ -0,0 +1,26 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import BaseMemory
class SimpleMemory(BaseMemory):
"""Simple memory for storing context or other information that shouldn't
ever change between prompts.
"""
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed, my memory is set in stone."""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import Any, Dict, List, Type
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.legacy.chains.llm import LLMChain
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.legacy.memory.prompt import SUMMARY_PROMPT
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
class SummarizerMixin(BaseModel):
"""Mixin for summarizer."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
prompt: BasePromptTemplate = SUMMARY_PROMPT
summary_message_cls: Type[BaseMessage] = SystemMessage
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt)
return chain.predict(summary=existing_summary, new_lines=new_lines)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""Conversation summarizer to chat memory."""
buffer: str = ""
memory_key: str = "history" #: :meta private:
@classmethod
def from_messages(
cls,
llm: BaseLanguageModel,
chat_memory: BaseChatMessageHistory,
*,
summarize_step: int = 2,
**kwargs: Any,
) -> ConversationSummaryMemory:
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
for i in range(0, len(obj.chat_memory.messages), summarize_step):
obj.buffer = obj.predict_new_summary(
obj.chat_memory.messages[i : i + summarize_step], obj.buffer
)
return obj
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)]
else:
buffer = self.buffer
return {self.memory_key: buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.buffer = self.predict_new_summary(
self.chat_memory.messages[-2:], self.buffer
)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""

View File

@@ -0,0 +1,77 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.legacy.memory.summary import SummarizerMixin
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
"""Buffer with summarizer for storing conversation memory."""
max_token_limit: int = 2000
moving_summary_buffer: str = ""
memory_key: str = "history"
@property
def buffer(self) -> List[BaseMessage]:
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer = self.buffer
if self.moving_summary_buffer != "":
first_messages: List[BaseMessage] = [
self.summary_message_cls(content=self.moving_summary_buffer)
]
buffer = first_messages + buffer
if self.return_messages:
final_buffer: Any = buffer
else:
final_buffer = get_buffer_string(
buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
)
return {self.memory_key: final_buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.prune()
def prune(self) -> None:
"""Prune buffer if it exceeds max token limit"""
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
self.moving_summary_buffer = self.predict_new_summary(
pruned_memory, self.moving_summary_buffer
)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.moving_summary_buffer = ""

View File

@@ -0,0 +1,58 @@
from typing import Any, Dict, List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.messages import BaseMessage, get_buffer_string
class ConversationTokenBufferMemory(BaseChatMemory):
"""Conversation chat memory with token limit."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
memory_key: str = "history"
max_token_limit: int = 2000
@property
def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is False."""
return get_buffer_string(
self.chat_memory.messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is True."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer. Pruned."""
super().save_context(inputs, outputs)
# Prune buffer if it exceeds max token limit
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)

View File

@@ -0,0 +1,20 @@
from typing import Any, Dict, List
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""
Get the prompt input key.
Args:
inputs: Dict[str, Any]
memory_variables: List[str]
Returns:
A prompt input key.
"""
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
return prompt_input_keys[0]
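A sketch of the key resolution above: memory variables and the special "stop" key are excluded, leaving exactly one prompt input (import path assumed per this PR):

from langchain_core.legacy.memory.utils import get_prompt_input_key  # assumed path

inputs = {"question": "Hi there", "history": "...", "stop": ["\n"]}
print(get_prompt_input_key(inputs, memory_variables=["history"]))  # 'question'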

View File

@@ -0,0 +1,100 @@
"""Class for a VectorStore-backed memory object."""
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.documents import Document
from langchain_core.legacy.memory.chat_memory import BaseMemory
from langchain_core.legacy.memory.utils import get_prompt_input_key
from langchain_core.pydantic_v1 import Field
from langchain_core.vectorstores import VectorStoreRetriever
class VectorStoreRetrieverMemory(BaseMemory):
"""VectorStoreRetriever-backed memory."""
retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStoreRetriever object to connect to."""
memory_key: str = "history" #: :meta private:
"""Key name to locate the memories in the result of load_memory_variables."""
input_key: Optional[str] = None
"""Key name to index the inputs to load_memory_variables."""
return_docs: bool = False
"""Whether or not to return the result of querying the database directly."""
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
"""Input keys to exclude in addition to memory key when constructing the document"""
@property
def memory_variables(self) -> List[str]:
"""The list of keys emitted from the load_memory_variables method."""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _documents_to_memory_variables(
self, docs: List[Document]
) -> Dict[str, Union[List[Document], str]]:
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])
else:
result = docs
return {self.memory_key: result}
def load_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.invoke(query)
return self._documents_to_memory_variables(docs)
async def aload_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = await self.retriever.ainvoke(query)
return self._documents_to_memory_variables(docs)
def _form_documents(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> List[Document]:
"""Format context from this conversation to buffer."""
# Each document should only include the current turn, not the chat history
exclude = set(self.exclude_input_keys)
exclude.add(self.memory_key)
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
texts = [
f"{k}: {v}"
for k, v in list(filtered_inputs.items()) + list(outputs.items())
]
page_content = "\n".join(texts)
return [Document(page_content=page_content)]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
self.retriever.add_documents(documents)
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
await self.retriever.aadd_documents(documents)
def clear(self) -> None:
"""Nothing to clear."""
async def aclear(self) -> None:
"""Nothing to clear."""

View File

@@ -3,6 +3,13 @@ from typing import Dict, Tuple
# First value is the value that it is serialized as
# Second value is the path to load it from
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
("langchain", "chains", "llm", "LLMChain"): (
"langchain_core",
"legacy",
"chains",
"llm",
"LLMChain",
),
("langchain", "schema", "messages", "AIMessage"): (
"langchain_core",
"messages",

View File

@@ -1,83 +1,5 @@
"""**Memory** maintains Chain state, incorporating context from past runs.
from langchain_core.legacy.memory import BaseMemory
**Class hierarchy for Memory:**
.. code-block::
BaseMemory --> <name>Memory --> <name>Memory # Examples: BaseChatMemory -> MotorheadMemory
""" # noqa: E501
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain_core.load.serializable import Serializable
from langchain_core.runnables import run_in_executor
class BaseMemory(Serializable, ABC):
"""Abstract base class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
return await run_in_executor(None, self.load_memory_variables, inputs)
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save the context of this chain run to memory."""
await run_in_executor(None, self.save_context, inputs, outputs)
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""
async def aclear(self) -> None:
"""Clear memory contents."""
await run_in_executor(None, self.clear)
__all__ = [
"BaseMemory",
]

View File

@@ -0,0 +1,35 @@
from typing import List
import pytest
from langchain_core.legacy.memory import CombinedMemory, ConversationBufferMemory
@pytest.fixture()
def example_memory() -> List[ConversationBufferMemory]:
example_1 = ConversationBufferMemory(memory_key="foo")
example_2 = ConversationBufferMemory(memory_key="bar")
example_3 = ConversationBufferMemory(memory_key="bar")
return [example_1, example_2, example_3]
def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None:
"""Test basic functionality of methods exposed by class"""
combined_memory = CombinedMemory(memories=[example_memory[0], example_memory[1]])
assert combined_memory.memory_variables == ["foo", "bar"]
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
combined_memory.save_context(
{"input": "Hello there"}, {"output": "Hello, how can I help you?"}
)
assert combined_memory.load_memory_variables({}) == {
"foo": "Human: Hello there\nAI: Hello, how can I help you?",
"bar": "Human: Hello there\nAI: Hello, how can I help you?",
}
combined_memory.clear()
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None:
"""Test raising error when repeated memory variables found"""
with pytest.raises(ValueError):
CombinedMemory(memories=[example_memory[1], example_memory[2]])

View File

@@ -48,7 +48,7 @@ _module_lookup = {
"GraphSparqlQAChain": "langchain.chains.graph_qa.sparql",
"create_history_aware_retriever": "langchain.chains.history_aware_retriever",
"HypotheticalDocumentEmbedder": "langchain.chains.hyde.base",
"LLMChain": "langchain.chains.llm",
"LLMChain": "langchain_core.legacy.chains.llm",
"LLMCheckerChain": "langchain.chains.llm_checker.base",
"LLMMathChain": "langchain.chains.llm_math.base",
"LLMRequestsChain": "langchain.chains.llm_requests",

View File

@@ -1,732 +1,6 @@
"""Base interface that all chains should implement."""
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast
from langchain_core.legacy.chains.base import Chain, _get_verbosity
import yaml
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
BaseCallbackManager,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.load.dump import dumpd
from langchain_core.memory import BaseMemory
from langchain_core.outputs import RunInfo
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
from langchain_core.runnables import (
RunnableConfig,
RunnableSerializable,
ensure_config,
run_in_executor,
)
from langchain_core.runnables.utils import create_model
from langchain.schema import RUN_KEY
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
from langchain.globals import get_verbose
return get_verbose()
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
The main methods exposed by chains are:
- `__call__`: Chains are callable. The `__call__` method is the primary way to
execute a Chain. This takes inputs as a dictionary and returns a
dictionary output.
- `run`: A convenience method that takes inputs as args/kwargs and returns the
output as a string or object. This method can only be used for a subset of
chains and cannot return as rich of an output as `__call__`.
"""
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start
and at the end of every chain. At the start, memory loads variables and passes
them along in the chain. At the end, it saves any returned variables.
There are many different types of memory - please see memory docs
for the full catalog."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Optional list of callback handlers (or callback manager). Defaults to None.
Callback handlers are called throughout the lifecycle of a call to a chain,
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Use `callbacks` instead."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
)
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
async def ainvoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = await self.aprep_inputs(input)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
if values.get("callbacks") is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks. "
"callback_manager is deprecated, callbacks is the preferred "
"parameter to pass in."
)
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""Set the chain verbosity.
Defaults to the global setting if not specified by the user.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.__call__`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within
`Chain.acall`, which is the user-facing wrapper method that handles
callbacks configuration and some input/output processing.
Args:
inputs: A dict of named inputs to the chain. Assumed to contain all inputs
specified in `Chain.input_keys`, including any inputs added by memory.
run_manager: The callbacks manager that contains the callback handlers for
this run of the chain.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
return await run_in_executor(
None, self._call, inputs, run_manager.get_sync() if run_manager else None
)
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return self.invoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
return_only_outputs: Whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
Returns:
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
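Example (illustrative sketch; same hypothetical single-input chain, called
from async code):
.. code-block:: python
await chain.acall({"question": "What is LangChain?"})
# -> {"question": "What is LangChain?", "text": "..."}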
"""
config = {
"callbacks": callbacks,
"tags": tags,
"metadata": metadata,
"run_name": run_name,
}
return await self.ainvoke(
inputs,
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
return_only_outputs=return_only_outputs,
include_run_info=include_run_info,
)
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
inputs: Dictionary of chain inputs, including any inputs added by chain
memory.
outputs: Dictionary of initial chain outputs.
return_only_outputs: Whether to only return the chain outputs. If False,
inputs are also added to the final outputs.
Returns:
A dict of the final chain outputs.
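Example (illustrative sketch; the keys shown are hypothetical):
.. code-block:: python
chain.prep_outputs({"question": "..."}, {"text": "..."})
# -> {"question": "...", "text": "..."} since return_only_outputs is False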
"""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
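Example (illustrative sketch; assumes a single 'question' input key and a
memory that exposes a 'history' variable):
.. code-block:: python
chain.prep_inputs("What is LangChain?")
# -> {"question": "What is LangChain?", "history": "..."}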
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = await self.memory.aload_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs
@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
return self.output_keys[0]
@deprecated("0.1.0", alternative="invoke", removal="0.2.0")
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
The main difference between this method and `Chain.__call__` is that this
method expects inputs to be passed directly in as positional arguments or
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
chain.run("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
chain.run(question=question, context=context)
# -> "The temperature in Boise is..."
"""
# Run at start to make sure this is possible/defined
_output_key = self._run_output_key
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
else:
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
@deprecated("0.1.0", alternative="ainvoke", removal="0.2.0")
async def arun(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
The main difference between this method and `Chain.__call__` is that this
method expects inputs to be passed directly in as positional arguments or
keyword arguments, whereas `Chain.__call__` expects a single input dictionary
with all the inputs.
Args:
*args: If the chain expects a single input, it can be passed in as the
sole positional argument.
callbacks: Callbacks to use for this chain run. These will be called in
addition to callbacks passed to the chain during construction, but only
these runtime callbacks will propagate to calls to other objects.
tags: List of string tags to pass to all callbacks. These will be passed in
addition to tags passed to the chain during construction, but only
these runtime tags will propagate to calls to other objects.
**kwargs: If the chain expects multiple inputs, they can be passed in
directly as keyword arguments.
Returns:
The chain output.
Example:
.. code-block:: python
# Suppose we have a single-input chain that takes a 'question' string:
await chain.arun("What's the temperature in Boise, Idaho?")
# -> "The temperature in Boise is..."
# Suppose we have a multi-input chain that takes a 'question' string
# and 'context' string:
question = "What's the temperature in Boise, Idaho?"
context = "Weather report for Boise, Idaho on 07/03/23..."
await chain.arun(question=question, context=context)
# -> "The temperature in Boise is..."
"""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
if kwargs and not args:
return (
await self.acall(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
method.
Returns:
A dictionary representation of the chain.
Example:
.. code-block:: python
chain.dict(exclude_unset=True)
# -> {"_type": "foo", "verbose": False, ...}
"""
_dict = super().dict(**kwargs)
try:
_dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
null.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
# Fetch dictionary to save
chain_dict = self.dict()
if "_type" not in chain_dict:
raise NotImplementedError(f"Chain {self} does not support saving.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
@deprecated("0.1.0", alternative="batch", removal="0.2.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
__all__ = [
"Chain",
"_get_verbosity",
]

View File

@@ -1,423 +1,5 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
from langchain_core.legacy.chains import LLMChain
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.language_models import (
BaseLanguageModel,
LanguageModelInput,
)
from langchain_core.load.dump import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableWithFallbacks,
)
from langchain_core.runnables.configurable import DynamicRunnable
from langchain_core.utils.input import get_colored_text
from langchain.chains.base import Chain
@deprecated(
since="0.1.17",
alternative="RunnableSequence, e.g., `prompt | llm`",
removal="0.3.0",
)
class LLMChain(Chain):
"""Chain to run queries against LLMs.
This class is deprecated. See below for an example implementation using
LangChain runnables:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = OpenAI()
chain = prompt | llm
chain.invoke("your adjective here")
Example:
.. code-block:: python
from langchain.chains import LLMChain
from langchain_community.llms import OpenAI
from langchain_core.prompts import PromptTemplate
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(
input_variables=["adjective"], template=prompt_template
)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: Union[
Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
]
"""Language model to call."""
output_key: str = "text" #: :meta private:
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
"""Output parser to use.
Defaults to one that takes the most likely string but does not change it
otherwise."""
return_final_only: bool = True
"""Whether to return only the final parsed result. Defaults to True.
If False, extra information about the generation is returned as well."""
llm_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
callbacks = run_manager.get_child() if run_manager else None
if isinstance(self.llm, BaseLanguageModel):
return self.llm.generate_prompt(
prompts,
stop,
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(List, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
else:
generations.append([Generation(text=res)])
return LLMResult(generations=generations)
async def agenerate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
callbacks = run_manager.get_child() if run_manager else None
if isinstance(self.llm, BaseLanguageModel):
return await self.llm.agenerate_prompt(
prompts,
stop,
callbacks=callbacks,
**self.llm_kwargs,
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(List, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
else:
generations.append([Generation(text=res)])
return LLMResult(generations=generations)
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
return [], stop
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
async def aprep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
return [], stop
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input_list": input_list},
)
try:
response = await self.agenerate(input_list, run_manager=run_manager)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
await run_manager.on_chain_end({"outputs": outputs})
return outputs
@property
def _run_output_key(self) -> str:
return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = await llm.apredict(adjective="funny")
"""
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_generation(result)
def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return generation
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_generation(result)
@property
def _chain_type(self) -> str:
return "llm_chain"
@classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template)
def _get_num_tokens(self, text: str) -> int:
return _get_language_model(self.llm).get_num_tokens(text)
def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
if isinstance(llm_like, BaseLanguageModel):
return llm_like
elif isinstance(llm_like, RunnableBinding):
return _get_language_model(llm_like.bound)
elif isinstance(llm_like, RunnableWithFallbacks):
return _get_language_model(llm_like.runnable)
elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
return _get_language_model(llm_like.default)
else:
raise ValueError(
f"Unable to extract BaseLanguageModel from llm_like object of type "
f"{type(llm_like)}"
)
__all__ = [
"LLMChain",
]

View File

@@ -45,29 +45,28 @@ from langchain_community.chat_message_histories import (
XataChatMessageHistory,
ZepChatMessageHistory,
)
from langchain.memory.buffer import (
ConversationBufferMemory,
ConversationStringBufferMemory,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import (
ConversationEntityMemory,
InMemoryEntityStore,
)
from langchain_community.memory import (
ConversationKGMemory,
MotorheadMemory,
RedisEntityStore,
SQLiteEntityStore,
UpstashRedisEntityStore,
ZepMemory,
)
from langchain_core.legacy.memory import (
CombinedMemory,
ConversationBufferMemory,
ConversationBufferWindowMemory,
ConversationEntityMemory,
ConversationStringBufferMemory,
ConversationSummaryBufferMemory,
ConversationSummaryMemory,
ConversationTokenBufferMemory,
InMemoryEntityStore,
ReadOnlySharedMemory,
SimpleMemory,
VectorStoreRetrieverMemory,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.motorhead_memory import MotorheadMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
from langchain.memory.summary import ConversationSummaryMemory
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.zep_memory import ZepMemory
__all__ = [
"AstraDBChatMessageHistory",

View File

@@ -1,136 +1,6 @@
from typing import Any, Dict, List, Optional
from langchain_core.legacy.memory import (
ConversationBufferMemory,
ConversationStringBufferMemory,
)
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_prompt_input_key
class ConversationBufferMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
@property
def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
async def abuffer(self) -> Any:
"""String buffer of memory."""
return (
await self.abuffer_as_messages()
if self.return_messages
else await self.abuffer_as_str()
)
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
return get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
return self._buffer_as_str(self.chat_memory.messages)
async def abuffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
messages = await self.chat_memory.aget_messages()
return self._buffer_as_str(messages)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.chat_memory.messages
async def abuffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return await self.chat_memory.aget_messages()
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
buffer = await self.abuffer()
return {self.memory_key: buffer}
class ConversationStringBufferMemory(BaseMemory):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: str = ""
output_key: Optional[str] = None
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that return messages is not True."""
if values.get("return_messages", False):
raise ValueError(
"return_messages must be False for ConversationStringBufferMemory"
)
return values
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return self.load_memory_variables(inputs)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer += "\n" + "\n".join([human, ai])
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
return self.save_context(inputs, outputs)
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""
async def aclear(self) -> None:
self.clear()
__all__ = ["ConversationBufferMemory", "ConversationStringBufferMemory"]

View File

@@ -1,47 +1,3 @@
from typing import Any, Dict, List, Union
from langchain_core.legacy.memory import ConversationBufferWindowMemory
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain.memory.chat_memory import BaseChatMemory
class ConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory inside a limited size window."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
"""Number of messages to store in buffer."""
@property
def buffer(self) -> Union[str, List[BaseMessage]]:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is False."""
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
return get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is True."""
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
__all__ = ["ConversationBufferWindowMemory"]

View File

@@ -1,75 +1,5 @@
import warnings
from abc import ABC
from typing import Any, Dict, Optional, Tuple
from langchain_core.legacy.memory.chat_memory import BaseChatMemory
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.memory import BaseMemory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import Field
from langchain.memory.utils import get_prompt_input_key
class BaseChatMemory(BaseMemory, ABC):
"""Abstract base class for chat memory."""
chat_memory: BaseChatMessageHistory = Field(
default_factory=InMemoryChatMessageHistory
)
output_key: Optional[str] = None
input_key: Optional[str] = None
return_messages: bool = False
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) == 1:
output_key = list(outputs.keys())[0]
elif "output" in outputs:
output_key = "output"
warnings.warn(
f"'{self.__class__.__name__}' got multiple output keys:"
f" {outputs.keys()}. The default 'output' key is being used."
f" If this is not desired, please manually set 'output_key'."
)
else:
raise ValueError(
f"Got multiple output keys: {outputs.keys()}, cannot "
f"determine which to store in memory. Please set the "
f"'output_key' explicitly."
)
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
)
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
await self.chat_memory.aadd_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
)
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
async def aclear(self) -> None:
"""Clear memory contents."""
await self.chat_memory.aclear()
__all__ = [
"BaseChatMemory",
]

View File

@@ -1,82 +1,3 @@
import warnings
from typing import Any, Dict, List, Set
from langchain_core.legacy.memory import CombinedMemory
from langchain_core.memory import BaseMemory
from langchain_core.pydantic_v1 import validator
from langchain.memory.chat_memory import BaseChatMemory
class CombinedMemory(BaseMemory):
"""Combining multiple memories' data together."""
memories: List[BaseMemory]
"""For tracking all the memories that should be accessed."""
@validator("memories")
def check_repeated_memory_variable(
cls, value: List[BaseMemory]
) -> List[BaseMemory]:
all_variables: Set[str] = set()
for val in value:
overlap = all_variables.intersection(val.memory_variables)
if overlap:
raise ValueError(
f"The same variables {overlap} are found in multiple"
"memory object, which is not allowed by CombinedMemory."
)
all_variables |= set(val.memory_variables)
return value
@validator("memories")
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value:
if isinstance(val, BaseChatMemory):
if val.input_key is None:
warnings.warn(
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}"
)
return value
@property
def memory_variables(self) -> List[str]:
"""All the memory variables that this instance provides."""
"""Collected from the all the linked memories."""
memory_variables = []
for memory in self.memories:
memory_variables.extend(memory.memory_variables)
return memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load all vars from sub-memories."""
memory_data: Dict[str, Any] = {}
# Collect vars from all sub-memories
for memory in self.memories:
data = memory.load_memory_variables(inputs)
for key, value in data.items():
if key in memory_data:
raise ValueError(
f"The variable {key} is repeated in the CombinedMemory."
)
memory_data[key] = value
return memory_data
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this session for every memory."""
# Save context for all sub-memories
for memory in self.memories:
memory.save_context(inputs, outputs)
def clear(self) -> None:
"""Clear context from this session for every memory."""
for memory in self.memories:
memory.clear()
__all__ = ["CombinedMemory"]

View File

@@ -1,483 +1,19 @@
import logging
from abc import ABC, abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional
from langchain_community.utilities.redis import get_client
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
ENTITY_SUMMARIZATION_PROMPT,
)
from langchain_community.memory.entity import (
RedisEntityStore,
SQLiteEntityStore,
UpstashRedisEntityStore,
)
from langchain.memory.utils import get_prompt_input_key
logger = logging.getLogger(__name__)
class BaseEntityStore(BaseModel, ABC):
"""Abstract base class for Entity store."""
@abstractmethod
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
"""Get entity value from store."""
pass
@abstractmethod
def set(self, key: str, value: Optional[str]) -> None:
"""Set entity value in store."""
pass
@abstractmethod
def delete(self, key: str) -> None:
"""Delete entity value from store."""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""Check if entity exists in store."""
pass
@abstractmethod
def clear(self) -> None:
"""Delete all entities from store."""
pass
class InMemoryEntityStore(BaseEntityStore):
"""In-memory Entity store."""
store: Dict[str, Optional[str]] = {}
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
return self.store.get(key, default)
def set(self, key: str, value: Optional[str]) -> None:
self.store[key] = value
def delete(self, key: str) -> None:
del self.store[key]
def exists(self, key: str) -> bool:
return key in self.store
def clear(self) -> None:
return self.store.clear()
class UpstashRedisEntityStore(BaseEntityStore):
"""Upstash Redis backed Entity store.
Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""
def __init__(
self,
session_id: str = "default",
url: str = "",
token: str = "",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
from upstash_redis import Redis
except ImportError:
raise ImportError(
"Could not import upstash_redis python package. "
"Please install it with `pip install upstash_redis`."
)
super().__init__(*args, **kwargs)
try:
self.redis_client = Redis(url=url, token=token)
except Exception:
logger.error("Upstash Redis instance could not be initiated.")
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl
@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
def exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
def clear(self) -> None:
def scan_and_delete(cursor: int) -> int:
cursor, keys_to_delete = self.redis_client.scan(
cursor, f"{self.full_key_prefix}:*"
)
self.redis_client.delete(*keys_to_delete)
return cursor
cursor = scan_and_delete(0)
while cursor != 0:
scan_and_delete(cursor)
class RedisEntityStore(BaseEntityStore):
"""Redis-backed Entity store.
Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""
redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
super().__init__(*args, **kwargs)
try:
self.redis_client = get_client(redis_url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl
@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
def exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
def clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch
for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)
class SQLiteEntityStore(BaseEntityStore):
"""SQLite-backed Entity store"""
session_id: str = "default"
table_name: str = "memory_store"
conn: Any = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def __init__(
self,
session_id: str = "default",
db_file: str = "entities.db",
table_name: str = "memory_store",
*args: Any,
**kwargs: Any,
):
try:
import sqlite3
except ImportError:
raise ImportError(
"Could not import sqlite3 python package. "
"Please install it with `pip install sqlite3`."
)
super().__init__(*args, **kwargs)
self.conn = sqlite3.connect(db_file)
self.session_id = session_id
self.table_name = table_name
self._create_table_if_not_exists()
@property
def full_table_name(self) -> str:
return f"{self.table_name}_{self.session_id}"
def _create_table_if_not_exists(self) -> None:
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
key TEXT PRIMARY KEY,
value TEXT
)
"""
with self.conn:
self.conn.execute(create_table_query)
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
query = f"""
SELECT value
FROM {self.full_table_name}
WHERE key = ?
"""
cursor = self.conn.execute(query, (key,))
result = cursor.fetchone()
if result is not None:
value = result[0]
return value
return default
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
query = f"""
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
VALUES (?, ?)
"""
with self.conn:
self.conn.execute(query, (key, value))
def delete(self, key: str) -> None:
query = f"""
DELETE FROM {self.full_table_name}
WHERE key = ?
"""
with self.conn:
self.conn.execute(query, (key,))
def exists(self, key: str) -> bool:
query = f"""
SELECT 1
FROM {self.full_table_name}
WHERE key = ?
LIMIT 1
"""
cursor = self.conn.execute(query, (key,))
result = cursor.fetchone()
return result is not None
def clear(self) -> None:
query = f"""
DELETE FROM {self.full_table_name}
"""
with self.conn:
self.conn.execute(query)
class ConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer memory.
Extracts named entities from the recent chat history and generates summaries.
With a swappable entity store, entities can be persisted across conversations.
Defaults to an in-memory entity store, and can be swapped out for a Redis,
SQLite, or other entity store.
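Example (illustrative sketch; `llm` stands in for any language model):
.. code-block:: python
memory = ConversationEntityMemory(llm=llm)
memory.load_memory_variables({"input": "Alice works at Acme Corp."})
memory.save_context({"input": "Alice works at Acme Corp."}, {"output": "Noted."})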
"""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
# Cache of recently detected entity names, if any
# It is updated when load_memory_variables is called:
entity_cache: List[str] = []
# Number of recent message pairs to consider when updating entities:
k: int = 3
chat_history_key: str = "history"
# Store to manage entity-related data:
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
@property
def buffer(self) -> List[BaseMessage]:
"""Access chat memory messages."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return ["entities", self.chat_history_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Returns chat history and all generated entities with summaries if available,
and updates or clears the recent entity cache.
New entity names may be discovered when this method is called, before their
summaries have been generated, so the returned entity values may be empty if
no descriptions have been generated yet.
"""
# Create an LLMChain for predicting entity names from the recent chat history:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
# Extract an arbitrary window of the last message pairs from
# the chat history, where the hyperparameter k is the
# number of message pairs:
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
# Generates a comma-separated list of named entities,
# e.g. "Jane, White House, UFO"
# or "NONE" if no named entities are extracted:
output = chain.predict(
history=buffer_string,
input=inputs[prompt_input_key],
)
# If no named entities are extracted, assigns an empty list.
if output.strip() == "NONE":
entities = []
else:
# Make a list of the extracted entities:
entities = [w.strip() for w in output.split(",")]
# Make a dictionary of entities with summary if exists:
entity_summaries = {}
for entity in entities:
entity_summaries[entity] = self.entity_store.get(entity, "")
# Replaces the entity name cache with the most recently discussed entities,
# or if no entities were extracted, clears the cache:
self.entity_cache = entities
# Should we return as message objects or as a string?
if self.return_messages:
# Get last `k` pair of chat messages:
buffer: Any = self.buffer[-self.k * 2 :]
else:
# Reuse the string we made earlier:
buffer = buffer_string
return {
self.chat_history_key: buffer,
"entities": entity_summaries,
}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""
Save context from this conversation history to the entity store.
Generates a summary for each entity in the entity cache by prompting
the model, and saves these summaries to the entity store.
"""
super().save_context(inputs, outputs)
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
# Extract an arbitrary window of the last message pairs from
# the chat history, where the hyperparameter k is the
# number of message pairs:
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
input_data = inputs[prompt_input_key]
# Create an LLMChain for predicting entity summarization from the context
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
# Generate new summaries for entities and save them in the entity store
for entity in self.entity_cache:
# Get existing summary if it exists
existing_summary = self.entity_store.get(entity, "")
output = chain.predict(
summary=existing_summary,
entity=entity,
history=buffer_string,
input=input_data,
)
# Save the updated summary to the entity store
self.entity_store.set(entity, output.strip())
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.entity_cache.clear()
self.entity_store.clear()
from langchain_core.legacy.memory.entity import (
BaseEntityStore,
ConversationEntityMemory,
InMemoryEntityStore,
)
__all__ = [
"BaseEntityStore",
"InMemoryEntityStore",
"UpstashRedisEntityStore",
"RedisEntityStore",
"SQLiteEntityStore",
"ConversationEntityMemory",
]

View File

@@ -1,133 +1,5 @@
from typing import Any, Dict, List, Type, Union
from langchain_community.memory import ConversationKGMemory
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
class ConversationKGMemory(BaseChatMemory):
"""Knowledge graph conversation memory.
Integrates with external knowledge graph to store and retrieve
information about knowledge triples in the conversation.
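Example (illustrative sketch; `llm` stands in for any language model and the
extracted triples depend on its output):
.. code-block:: python
memory = ConversationKGMemory(llm=llm)
memory.save_context({"input": "Sam lives in Boise"}, {"output": "Good to know."})
memory.load_memory_variables({"input": "Where does Sam live?"})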
"""
k: int = 2
"""Number of previous utterances to include in the context."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)
summary_strings = []
for entity in entities:
knowledge = self.kg.get_entity_knowledge(entity)
if knowledge:
summary = f"On {entity}: {'. '.join(knowledge)}."
summary_strings.append(summary)
context: Union[str, List]
if not summary_strings:
context = [] if self.return_messages else ""
elif self.return_messages:
context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else:
context = "\n".join(summary_strings)
return {self.memory_key: context}
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
"""Get the output key for the prompt."""
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
return list(outputs.keys())[0]
return self.output_key
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
return get_entities(output)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
verbose=True,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()
__all__ = [
"ConversationKGMemory",
]

View File

@@ -1,94 +1,5 @@
from typing import Any, Dict, List, Optional
from langchain_community.memory import MotorheadMemory
import requests
from langchain_core.messages import get_buffer_string
from langchain.memory.chat_memory import BaseChatMemory
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
# LOCAL_URL = "http://localhost:8080"
class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL
timeout: int = 3000
memory_key: str = "history"
session_id: str
context: Optional[str] = None
# Managed Params
api_key: Optional[str] = None
client_id: Optional[str] = None
def __get_headers(self) -> Dict[str, str]:
is_managed = self.url == MANAGED_URL
headers = {
"Content-Type": "application/json",
}
if is_managed and not (self.api_key and self.client_id):
raise ValueError(
"""
You must provide an API key and a client ID to use the managed
version of Motorhead. Visit https://getmetal.io for more information.
"""
)
if is_managed and self.api_key and self.client_id:
headers["x-metal-api-key"] = self.api_key
headers["x-metal-client-id"] = self.client_id
return headers
async def init(self) -> None:
res = requests.get(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
headers=self.__get_headers(),
)
res_data = res.json()
res_data = res_data.get("data", res_data) # Handle Managed Version
messages = res_data.get("messages", [])
context = res_data.get("context", "NONE")
for message in reversed(messages):
if message["role"] == "AI":
self.chat_memory.add_ai_message(message["content"])
else:
self.chat_memory.add_user_message(message["content"])
if context and context != "NONE":
self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
if self.return_messages:
return {self.memory_key: self.chat_memory.messages}
else:
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
@property
def memory_variables(self) -> List[str]:
return [self.memory_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
input_str, output_str = self._get_input_output(inputs, outputs)
requests.post(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
json={
"messages": [
{"role": "Human", "content": f"{input_str}"},
{"role": "AI", "content": f"{output_str}"},
]
},
headers=self.__get_headers(),
)
super().save_context(inputs, outputs)
def delete_session(self) -> None:
"""Delete a session"""
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
__all__ = [
"MotorheadMemory",
]

View File

@@ -1,165 +1,15 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI.
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
You are constantly learning and improving, and your capabilities are constantly evolving. You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. You have access to some personalized information provided by the human in the Context section below. Additionally, you are able to generate your own text based on the input you receive, allowing you to engage in discussions and provide explanations and descriptions on a wide range of topics.
Overall, you are a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether the human needs help with a specific question or just wants to have a conversation about a particular topic, you are here to assist.
Context:
{entities}
Current conversation:
{history}
Last line:
Human: {input}
You:"""
ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate(
input_variables=["entities", "history", "input"],
template=_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE,
)
from langchain_core.legacy.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
ENTITY_MEMORY_CONVERSATION_TEMPLATE,
ENTITY_SUMMARIZATION_PROMPT,
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
SUMMARY_PROMPT,
)
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
EXAMPLE
Current summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
New lines of conversation:
Human: Why do you think artificial intelligence is a force for good?
AI: Because artificial intelligence will help humans reach their full potential.
New summary:
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
END OF EXAMPLE
Current summary:
{summary}
New lines of conversation:
{new_lines}
New summary:"""
SUMMARY_PROMPT = PromptTemplate(
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
)
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
Output: Langchain, Person #2
END OF EXAMPLE
Conversation history (for reference only):
{history}
Last line of conversation (for extraction):
Human: {input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
Full conversation history (for context):
{history}
Entity to summarize:
{entity}
Existing summary of {entity}:
{summary}
Last line of conversation:
Human: {input}
Updated summary:"""
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
input_variables=["entity", "summary", "history", "input"],
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
)
KG_TRIPLE_DELIMITER = "<|>"
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
"You are a networked intelligence helping a human track knowledge triples"
" about all relevant people, things, concepts, etc. and integrating"
" them with your knowledge stored within your weights"
" as well as that stored in a knowledge graph."
" Extract all of the knowledge triples from the last line of conversation."
" A knowledge triple is a clause that contains a subject, a predicate,"
" and an object. The subject is the entity being described,"
" the predicate is the property of the subject that is being"
" described, and the object is the value of the property.\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: Did you hear aliens landed in Area 51?\n"
"AI: No, I didn't hear that. What do you know about Area 51?\n"
"Person #1: It's a secret military base in Nevada.\n"
"AI: What do you know about Nevada?\n"
"Last line of conversation:\n"
"Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: Hello.\n"
"AI: Hi! How are you?\n"
"Person #1: I'm good. How are you?\n"
"AI: I'm good too.\n"
"Last line of conversation:\n"
"Person #1: I'm going to the store.\n\n"
"Output: NONE\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"Conversation history:\n"
"Person #1: What do you know about Descartes?\n"
"AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n"
"Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n"
"AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n"
"Last line of conversation:\n"
"Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n"
"END OF EXAMPLE\n\n"
"Conversation history (for reference only):\n"
"{history}"
"\nLast line of conversation (for extraction):\n"
"Human: {input}\n\n"
"Output:"
)
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["history", "input"],
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
)
__all__ = [
"ENTITY_SUMMARIZATION_PROMPT",
"ENTITY_EXTRACTION_PROMPT",
"KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT",
"ENTITY_MEMORY_CONVERSATION_TEMPLATE",
"SUMMARY_PROMPT",
]
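
A small sketch of how these prompt templates are filled in before being sent to an LLM. The import path is the legacy location used by this changeset, and the conversation text is illustrative.

# Sketch: render two of the prompts above with illustrative values.
from langchain_core.legacy.memory.prompt import ENTITY_EXTRACTION_PROMPT, SUMMARY_PROMPT

extraction_text = ENTITY_EXTRACTION_PROMPT.format(
    history="Human: I met Alice in Paris.\nAI: That sounds lovely!",
    input="We visited the Louvre together.",
)

summary_text = SUMMARY_PROMPT.format(
    summary="The human describes a trip.",
    new_lines="Human: We visited the Louvre together.",
)

print(extraction_text)
print(summary_text)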

View File

@@ -1,26 +1,5 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import ReadOnlySharedMemory
from langchain_core.memory import BaseMemory
class ReadOnlySharedMemory(BaseMemory):
"""A memory wrapper that is read-only and cannot be changed."""
memory: BaseMemory
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load memory variables from memory."""
return self.memory.load_memory_variables(inputs)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
__all__ = [
"ReadOnlySharedMemory",
]
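
A short sketch of sharing one memory object read-only between chains. The ReadOnlySharedMemory import uses the legacy path introduced by this changeset, and ConversationBufferMemory is assumed to be available from the pre-existing langchain.memory package.

# Sketch: expose an existing memory to a second chain without letting it write.
from langchain.memory import ConversationBufferMemory          # assumed available
from langchain_core.legacy.memory import ReadOnlySharedMemory  # path per the diff above

shared = ConversationBufferMemory(memory_key="history")
shared.save_context({"input": "The launch is on Friday."}, {"output": "Noted."})

readonly = ReadOnlySharedMemory(memory=shared)

# Reads are delegated to the wrapped memory...
print(readonly.load_memory_variables({}))
# ...while writes are silently dropped, so the shared history stays intact.
readonly.save_context({"input": "ignore this"}, {"output": "ignored"})
print(shared.load_memory_variables({}))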

View File

@@ -1,26 +1,3 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import SimpleMemory
from langchain_core.memory import BaseMemory
class SimpleMemory(BaseMemory):
"""Simple memory for storing context or other information that shouldn't
ever change between prompts.
"""
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed, my memory is set in stone."""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass
__all__ = ["SimpleMemory"]
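
A minimal sketch of SimpleMemory, which returns the same static dictionary on every call; the values below are illustrative and the import uses the legacy path from this changeset.

# Sketch: static context that never changes between prompts.
from langchain_core.legacy.memory import SimpleMemory  # path per the diff above

memory = SimpleMemory(memories={"project": "Q3 roadmap", "owner": "Alice"})

print(memory.memory_variables)           # -> ['project', 'owner']
print(memory.load_memory_variables({}))  # -> {'project': 'Q3 roadmap', 'owner': 'Alice'}

# save_context and clear are deliberate no-ops, so the values above persist.
memory.save_context({"input": "anything"}, {"output": "anything"})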

View File

@@ -1,98 +1,11 @@
from __future__ import annotations
from typing import Any, Dict, List, Type
from langchain_core.legacy.memory.summary import (
ConversationSummaryMemory,
SummarizerMixin,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
class SummarizerMixin(BaseModel):
"""Mixin for summarizer."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
prompt: BasePromptTemplate = SUMMARY_PROMPT
summary_message_cls: Type[BaseMessage] = SystemMessage
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt)
return chain.predict(summary=existing_summary, new_lines=new_lines)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""Conversation summarizer to chat memory."""
buffer: str = ""
memory_key: str = "history" #: :meta private:
@classmethod
def from_messages(
cls,
llm: BaseLanguageModel,
chat_memory: BaseChatMessageHistory,
*,
summarize_step: int = 2,
**kwargs: Any,
) -> ConversationSummaryMemory:
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
for i in range(0, len(obj.chat_memory.messages), summarize_step):
obj.buffer = obj.predict_new_summary(
obj.chat_memory.messages[i : i + summarize_step], obj.buffer
)
return obj
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)]
else:
buffer = self.buffer
return {self.memory_key: buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.buffer = self.predict_new_summary(
self.chat_memory.messages[-2:], self.buffer
)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""
__all__ = [
"ConversationSummaryMemory",
"SummarizerMixin",
]
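
A rough sketch of the summarization loop, using a canned fake LLM so the example runs without network access. FakeListLLM and both import paths are assumptions; a real chat model would normally supply the summaries.

# Sketch: each save_context call folds the last turn into a running summary.
from langchain_community.llms.fake import FakeListLLM  # assumed test helper
from langchain.memory.summary import ConversationSummaryMemory

# The fake LLM returns these strings in order, standing in for real summaries.
llm = FakeListLLM(responses=["The human greeted the AI and asked about the weather."])

memory = ConversationSummaryMemory(llm=llm)
memory.save_context({"input": "Hi! How's the weather?"}, {"output": "Sunny all week."})

# The buffer now holds the (fake) progressive summary.
print(memory.load_memory_variables({}))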

View File

@@ -1,78 +1,3 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import ConversationSummaryBufferMemory
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.pydantic_v1 import root_validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
"""Buffer with summarizer for storing conversation memory."""
max_token_limit: int = 2000
moving_summary_buffer: str = ""
memory_key: str = "history"
@property
def buffer(self) -> List[BaseMessage]:
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer = self.buffer
if self.moving_summary_buffer != "":
first_messages: List[BaseMessage] = [
self.summary_message_cls(content=self.moving_summary_buffer)
]
buffer = first_messages + buffer
if self.return_messages:
final_buffer: Any = buffer
else:
final_buffer = get_buffer_string(
buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
)
return {self.memory_key: final_buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.prune()
def prune(self) -> None:
"""Prune buffer if it exceeds max token limit"""
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
self.moving_summary_buffer = self.predict_new_summary(
pruned_memory, self.moving_summary_buffer
)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.moving_summary_buffer = ""
__all__ = ["ConversationSummaryBufferMemory"]
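
A sketch of the pruning behaviour, under the assumption that langchain-openai is installed and OPENAI_API_KEY is set; any model that implements get_num_tokens_from_messages would work in place of ChatOpenAI, and the module path is assumed from the file layout.

# Sketch: old turns are summarized away once the buffer passes max_token_limit.
from langchain_openai import ChatOpenAI  # assumed available with an API key configured
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory

memory = ConversationSummaryBufferMemory(
    llm=ChatOpenAI(),
    max_token_limit=40,  # deliberately small so pruning kicks in quickly
)

memory.save_context({"input": "Tell me about the Apollo program."},
                    {"output": "Apollo was NASA's crewed lunar program in the 1960s-70s."})
memory.save_context({"input": "And Apollo 11?"},
                    {"output": "Apollo 11 landed the first humans on the Moon in 1969."})

# Older messages that no longer fit are condensed into moving_summary_buffer.
print(memory.moving_summary_buffer)
print(memory.load_memory_variables({}))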

View File

@@ -1,59 +1,5 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory import ConversationTokenBufferMemory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain.memory.chat_memory import BaseChatMemory
class ConversationTokenBufferMemory(BaseChatMemory):
"""Conversation chat memory with token limit."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLanguageModel
memory_key: str = "history"
max_token_limit: int = 2000
@property
def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is False."""
return get_buffer_string(
self.chat_memory.messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is True."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer. Pruned."""
super().save_context(inputs, outputs)
# Prune buffer if it exceeds max token limit
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
__all__ = [
"ConversationTokenBufferMemory",
]
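
A companion sketch for the token buffer: like a plain buffer, but the oldest messages are dropped (not summarized) once the token count exceeds max_token_limit. Same assumptions as above about langchain-openai and an API key being available for token counting; the module path is assumed from the file layout.

# Sketch: only the most recent turns that fit under the limit are kept.
from langchain_openai import ChatOpenAI  # assumed available with an API key configured
from langchain.memory.token_buffer import ConversationTokenBufferMemory

memory = ConversationTokenBufferMemory(llm=ChatOpenAI(), max_token_limit=50)

memory.save_context({"input": "First question"}, {"output": "First answer"})
memory.save_context({"input": "Second question"}, {"output": "Second answer"})

# Earlier turns are popped from the front of the buffer if the limit is exceeded.
print(memory.load_memory_variables({}))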

View File

@@ -1,20 +1,3 @@
from typing import Any, Dict, List
from langchain_core.legacy.memory.utils import get_prompt_input_key
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""
Get the prompt input key.
Args:
inputs: Dict[str, Any]
memory_variables: List[str]
Returns:
A prompt input key.
"""
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
return prompt_input_keys[0]
__all__ = ["get_prompt_input_key"]
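
A quick sketch of the helper: it picks the one chain input that is neither a memory variable nor the special "stop" key. The input dict below is illustrative, and the import uses the legacy path from this changeset.

# Sketch: resolve which input key should be used to format the prompt.
from langchain_core.legacy.memory.utils import get_prompt_input_key  # path per the diff above

inputs = {"question": "What is LangChain?", "history": "...", "stop": ["\n"]}
key = get_prompt_input_key(inputs, memory_variables=["history"])
print(key)  # -> "question"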

View File

@@ -1,101 +1,5 @@
"""Class for a VectorStore-backed memory object."""
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.legacy.memory import VectorStoreRetrieverMemory
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key
class VectorStoreRetrieverMemory(BaseMemory):
"""VectorStoreRetriever-backed memory."""
retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStoreRetriever object to connect to."""
memory_key: str = "history" #: :meta private:
"""Key name to locate the memories in the result of load_memory_variables."""
input_key: Optional[str] = None
"""Key name to index the inputs to load_memory_variables."""
return_docs: bool = False
"""Whether or not to return the result of querying the database directly."""
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
"""Input keys to exclude in addition to memory key when constructing the document"""
@property
def memory_variables(self) -> List[str]:
"""The list of keys emitted from the load_memory_variables method."""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _documents_to_memory_variables(
self, docs: List[Document]
) -> Dict[str, Union[List[Document], str]]:
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])
else:
result = docs
return {self.memory_key: result}
def load_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.invoke(query)
return self._documents_to_memory_variables(docs)
async def aload_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = await self.retriever.ainvoke(query)
return self._documents_to_memory_variables(docs)
def _form_documents(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> List[Document]:
"""Format context from this conversation to buffer."""
# Each document should only include the current turn, not the chat history
exclude = set(self.exclude_input_keys)
exclude.add(self.memory_key)
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
texts = [
f"{k}: {v}"
for k, v in list(filtered_inputs.items()) + list(outputs.items())
]
page_content = "\n".join(texts)
return [Document(page_content=page_content)]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
self.retriever.add_documents(documents)
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
await self.retriever.aadd_documents(documents)
def clear(self) -> None:
"""Nothing to clear."""
async def aclear(self) -> None:
"""Nothing to clear."""
__all__ = ["VectorStoreRetrieverMemory"]
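
A sketch of the retriever-backed memory: each saved turn becomes a document, and load_memory_variables retrieves turns by similarity to the current input. FAISS and FakeEmbeddings are assumptions chosen so the example needs no external services (faiss-cpu must be installed); with fake embeddings the retrieved turn is essentially arbitrary, whereas real embeddings would return the semantically closest one.

# Sketch: conversation turns are stored as documents and recalled by similarity.
from langchain_community.embeddings import FakeEmbeddings  # assumed test helper
from langchain_community.vectorstores import FAISS          # assumed; requires faiss-cpu
from langchain.memory import VectorStoreRetrieverMemory

store = FAISS.from_texts(["placeholder"], FakeEmbeddings(size=32))
retriever = store.as_retriever(search_kwargs={"k": 1})

memory = VectorStoreRetrieverMemory(retriever=retriever)
memory.save_context({"input": "My favourite sport is curling."}, {"output": "Noted!"})

# The query is embedded and one stored turn is returned as the history string.
print(memory.load_memory_variables({"input": "What sport do I like?"}))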

View File

@@ -1,125 +1,3 @@
from __future__ import annotations
from langchain_community.memory import ZepMemory
from typing import Any, Dict, Optional
from langchain_community.chat_message_histories import ZepChatMessageHistory
from langchain.memory import ConversationBufferMemory
class ZepMemory(ConversationBufferMemory):
"""Persist your chain history to the Zep MemoryStore.
The number of messages returned by Zep and when the Zep server summarizes chat
histories is configurable. See the Zep documentation for more details.
Documentation: https://docs.getzep.com
Example:
.. code-block:: python
memory = ZepMemory(
session_id=session_id, # Identifies your user or a user's session
url=ZEP_API_URL, # Your Zep server's URL
api_key=<your_api_key>, # Optional
memory_key="history", # Ensure this matches the key used in
# chain's prompt template
return_messages=True, # Does your prompt template expect a string
# or a list of Messages?
)
chain = LLMChain(memory=memory,...) # Configure your chain to use the ZepMemory
instance
Note:
To persist metadata alongside your chat history, you will need to create a
custom Chain class that overrides the `prep_outputs` method to include the metadata
in the call to `self.memory.save_context`.
Zep - Fast, scalable building blocks for LLM Apps
=========
Zep is an open source platform for productionizing LLM apps. Go from a prototype
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
rewriting code.
For server installation instructions and more, see:
https://docs.getzep.com/deployment/quickstart/
For more information on the zep-python package, see:
https://github.com/getzep/zep-python
"""
chat_memory: ZepChatMessageHistory
def __init__(
self,
session_id: str,
url: str = "http://localhost:8000",
api_key: Optional[str] = None,
output_key: Optional[str] = None,
input_key: Optional[str] = None,
return_messages: bool = False,
human_prefix: str = "Human",
ai_prefix: str = "AI",
memory_key: str = "history",
):
"""Initialize ZepMemory.
Args:
session_id (str): Identifies your user or a user's session
url (str, optional): Your Zep server's URL. Defaults to
"http://localhost:8000".
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
output_key (Optional[str], optional): The key to use for the output message.
Defaults to None.
input_key (Optional[str], optional): The key to use for the input message.
Defaults to None.
return_messages (bool, optional): Does your prompt template expect a string
or a list of Messages? Defaults to False
i.e. return a string.
human_prefix (str, optional): The prefix to use for human messages.
Defaults to "Human".
ai_prefix (str, optional): The prefix to use for AI messages.
Defaults to "AI".
memory_key (str, optional): The key to use for the memory.
Defaults to "history".
Ensure that this matches the key used in
chain's prompt template.
"""
chat_message_history = ZepChatMessageHistory(
session_id=session_id,
url=url,
api_key=api_key,
)
super().__init__(
chat_memory=chat_message_history,
output_key=output_key,
input_key=input_key,
return_messages=return_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
memory_key=memory_key,
)
def save_context(
self,
inputs: Dict[str, Any],
outputs: Dict[str, str],
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save context from this conversation to buffer.
Args:
inputs (Dict[str, Any]): The inputs to the chain.
outputs (Dict[str, str]): The outputs from the chain.
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
the context. Defaults to None
Returns:
None
"""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)
__all__ = ["ZepMemory"]
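
A basic usage sketch for ZepMemory, assuming zep-python is installed and a Zep server is running locally; the session id and URL below are placeholders. The optional metadata argument shown here matches the save_context signature above.

# Sketch only: requires a reachable Zep server; session_id/url are placeholders.
from langchain_community.memory import ZepMemory  # path per the diff above

memory = ZepMemory(
    session_id="user-123",          # hypothetical session identifier
    url="http://localhost:8000",
    memory_key="history",
    return_messages=True,
)

# metadata is optional and is persisted alongside the two messages in Zep.
memory.save_context(
    {"input": "My flight leaves at 6pm."},
    {"output": "Got it, I'll remind you."},
    metadata={"channel": "sms"},
)

print(memory.load_memory_variables({}))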