mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
4 Commits
wfh/use_na
...
eugene/ent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bef75b862e | ||
|
|
88e26e8125 | ||
|
|
b2bada02c1 | ||
|
|
a4f6b91973 |
268
libs/community/langchain_community/memory/entity.py
Normal file
268
libs/community/langchain_community/memory/entity.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import logging
|
||||
from itertools import islice
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from langchain_core.memory import BaseEntityStore
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
|
||||
"""Upstash Redis backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import upstash_redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception:
|
||||
logger.error("Upstash Redis instance could not be initiated.")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
def scan_and_delete(cursor: int) -> int:
|
||||
cursor, keys_to_delete = self.redis_client.scan(
|
||||
cursor, f"{self.full_key_prefix}:*"
|
||||
)
|
||||
self.redis_client.delete(*keys_to_delete)
|
||||
return cursor
|
||||
|
||||
cursor = scan_and_delete(0)
|
||||
while cursor != 0:
|
||||
scan_and_delete(cursor)
|
||||
|
||||
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
redis_client: Any
|
||||
session_id: str = "default"
|
||||
key_prefix: str = "memory_store"
|
||||
ttl: Optional[int] = 60 * 60 * 24
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
# iterate a list in batches of size batch_size
|
||||
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
||||
iterator = iter(iterable)
|
||||
while batch := list(islice(iterator, batch_size)):
|
||||
yield batch
|
||||
|
||||
for keybatch in batched(
|
||||
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
|
||||
):
|
||||
self.redis_client.delete(*keybatch)
|
||||
|
||||
|
||||
class SQLiteEntityStore(BaseEntityStore):
|
||||
"""SQLite-backed Entity store"""
|
||||
|
||||
session_id: str = "default"
|
||||
table_name: str = "memory_store"
|
||||
conn: Any = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
db_file: str = "entities.db",
|
||||
table_name: str = "memory_store",
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sqlite3 python package. "
|
||||
"Please install it with `pip install sqlite3`."
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conn = sqlite3.connect(db_file)
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
@property
|
||||
def full_table_name(self) -> str:
|
||||
return f"{self.table_name}_{self.session_id}"
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(create_table_query)
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
query = f"""
|
||||
SELECT value
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
if result is not None:
|
||||
value = result[0]
|
||||
return value
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
query = f"""
|
||||
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
|
||||
VALUES (?, ?)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key, value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key,))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
query = f"""
|
||||
SELECT 1
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
def clear(self) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query)
|
||||
@@ -1,18 +1,11 @@
|
||||
"""**Memory** maintains Chain state, incorporating context from past runs.
|
||||
|
||||
**Class hierarchy for Memory:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseMemory --> <name>Memory --> <name>Memory # Examples: BaseChatMemory -> MotorheadMemory
|
||||
|
||||
""" # noqa: E501
|
||||
"""Memory classes that help store and retrieve various bits of information."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
@@ -81,3 +74,53 @@ class BaseMemory(Serializable, ABC):
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await run_in_executor(None, self.clear)
|
||||
|
||||
|
||||
class BaseEntityStore(BaseModel, ABC):
|
||||
"""Abstract base class for Entity store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""Get entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
"""Set entity value in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if entity exists in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Delete all entities from store."""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryEntityStore(BaseEntityStore):
|
||||
"""In-memory Entity store."""
|
||||
|
||||
store: Dict[str, Optional[str]] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
del self.store[key]
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.store
|
||||
|
||||
def clear(self) -> None:
|
||||
return self.store.clear()
|
||||
|
||||
@@ -45,6 +45,12 @@ from langchain_community.chat_message_histories import (
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
from langchain_community.memory.entity import (
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain_core.memory import InMemoryEntityStore
|
||||
|
||||
from langchain.memory.buffer import (
|
||||
ConversationBufferMemory,
|
||||
@@ -54,10 +60,6 @@ from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain.memory.combined import CombinedMemory
|
||||
from langchain.memory.entity import (
|
||||
ConversationEntityMemory,
|
||||
InMemoryEntityStore,
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain.memory.kg import ConversationKGMemory
|
||||
from langchain.memory.motorhead_memory import MotorheadMemory
|
||||
@@ -86,23 +88,23 @@ __all__ = [
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"InMemoryEntityStore",
|
||||
"MomentoChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"MotorheadMemory",
|
||||
"PostgresChatMessageHistory",
|
||||
"ReadOnlySharedMemory",
|
||||
"RedisChatMessageHistory",
|
||||
"RedisEntityStore",
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
"SQLChatMessageHistory",
|
||||
"SQLiteEntityStore",
|
||||
"SimpleMemory",
|
||||
"StreamlitChatMessageHistory",
|
||||
"VectorStoreRetrieverMemory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"ZepMemory",
|
||||
"UpstashRedisEntityStore",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
"RedisEntityStore",
|
||||
"SQLiteEntityStore",
|
||||
"UpstashRedisEntityStore",
|
||||
"InMemoryEntityStore",
|
||||
]
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import islice
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
from langchain_community.memory.entity import SQLiteEntityStore, UpstashRedisEntityStore
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.memory import BaseEntityStore, InMemoryEntityStore
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@@ -17,317 +15,6 @@ from langchain.memory.prompt import (
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseEntityStore(BaseModel, ABC):
|
||||
"""Abstract base class for Entity store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""Get entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
"""Set entity value in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete entity value from store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if entity exists in store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Delete all entities from store."""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryEntityStore(BaseEntityStore):
|
||||
"""In-memory Entity store."""
|
||||
|
||||
store: Dict[str, Optional[str]] = {}
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
return self.store.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
del self.store[key]
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.store
|
||||
|
||||
def clear(self) -> None:
|
||||
return self.store.clear()
|
||||
|
||||
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
|
||||
"""Upstash Redis backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import upstash_redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception:
|
||||
logger.error("Upstash Redis instance could not be initiated.")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"Upstash Redis MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"Redis MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
def scan_and_delete(cursor: int) -> int:
|
||||
cursor, keys_to_delete = self.redis_client.scan(
|
||||
cursor, f"{self.full_key_prefix}:*"
|
||||
)
|
||||
self.redis_client.delete(*keys_to_delete)
|
||||
return cursor
|
||||
|
||||
cursor = scan_and_delete(0)
|
||||
while cursor != 0:
|
||||
scan_and_delete(cursor)
|
||||
|
||||
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
redis_client: Any
|
||||
session_id: str = "default"
|
||||
key_prefix: str = "memory_store"
|
||||
ttl: Optional[int] = 60 * 60 * 24
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: Optional[int] = 60 * 60 * 24,
|
||||
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError as error:
|
||||
logger.error(error)
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
||||
return res
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
||||
)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
def clear(self) -> None:
|
||||
# iterate a list in batches of size batch_size
|
||||
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
||||
iterator = iter(iterable)
|
||||
while batch := list(islice(iterator, batch_size)):
|
||||
yield batch
|
||||
|
||||
for keybatch in batched(
|
||||
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
|
||||
):
|
||||
self.redis_client.delete(*keybatch)
|
||||
|
||||
|
||||
class SQLiteEntityStore(BaseEntityStore):
|
||||
"""SQLite-backed Entity store"""
|
||||
|
||||
session_id: str = "default"
|
||||
table_name: str = "memory_store"
|
||||
conn: Any = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
db_file: str = "entities.db",
|
||||
table_name: str = "memory_store",
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sqlite3 python package. "
|
||||
"Please install it with `pip install sqlite3`."
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.conn = sqlite3.connect(db_file)
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
@property
|
||||
def full_table_name(self) -> str:
|
||||
return f"{self.table_name}_{self.session_id}"
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(create_table_query)
|
||||
|
||||
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
query = f"""
|
||||
SELECT value
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
if result is not None:
|
||||
value = result[0]
|
||||
return value
|
||||
return default
|
||||
|
||||
def set(self, key: str, value: Optional[str]) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
query = f"""
|
||||
INSERT OR REPLACE INTO {self.full_table_name} (key, value)
|
||||
VALUES (?, ?)
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key, value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query, (key,))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
query = f"""
|
||||
SELECT 1
|
||||
FROM {self.full_table_name}
|
||||
WHERE key = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = self.conn.execute(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
def clear(self) -> None:
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
"""
|
||||
with self.conn:
|
||||
self.conn.execute(query)
|
||||
|
||||
|
||||
class ConversationEntityMemory(BaseChatMemory):
|
||||
"""Entity extractor & summarizer memory.
|
||||
@@ -481,3 +168,10 @@ class ConversationEntityMemory(BaseChatMemory):
|
||||
self.chat_memory.clear()
|
||||
self.entity_cache.clear()
|
||||
self.entity_store.clear()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConversationEntityMemory",
|
||||
"UpstashRedisEntityStore",
|
||||
"SQLiteEntityStore",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user