mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-04-27 11:21:34 +00:00
feat(scripts): Wipe qdrant and obtain db Stats command (#1783)
This commit is contained in:
parent
b3b0140e24
commit
ea153fb92f
3
Makefile
3
Makefile
@@ -51,6 +51,9 @@ api-docs:
|
||||
# Helper targets: document ingestion plus database maintenance,
# both delegated to the project's Python scripts via poetry.
ingest:
	@poetry run python scripts/ingest_folder.py $(call args)

stats:
	poetry run python scripts/utils.py stats

wipe:
	poetry run python scripts/utils.py wipe
|
199
scripts/utils.py
199
scripts/utils.py
@@ -1,26 +1,12 @@
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from private_gpt.paths import local_data_path
|
||||
from private_gpt.settings.settings import settings
|
||||
|
||||
|
||||
def wipe() -> None:
    """Wipe every configured store, dispatching on the database backend."""
    # Maps the configured database name to the function that wipes it.
    handlers = {
        "simple": wipe_simple,  # node store
        "chroma": wipe_chroma,  # vector store
        "postgres": wipe_postgres,  # node, index and vector store
    }
    for dbtype in ("nodestore", "vectorstore"):
        database = getattr(settings(), dbtype).database
        handler = handlers.get(database)
        if handler is None:
            print(f"Unable to wipe database '{database}' for '{dbtype}'")
        else:
            handler(dbtype)
|
||||
|
||||
|
||||
def wipe_file(file: str) -> None:
    """Delete *file* if it is a regular file; a missing path is a no-op."""
    if not os.path.isfile(file):
        return
    os.remove(file)
|
||||
@@ -50,62 +36,149 @@ def wipe_tree(path: str) -> None:
|
||||
continue
|
||||
|
||||
|
||||
def wipe_simple(dbtype: str) -> None:
|
||||
assert dbtype == "nodestore"
|
||||
from llama_index.core.storage.docstore.types import (
|
||||
DEFAULT_PERSIST_FNAME as DOCSTORE,
|
||||
)
|
||||
from llama_index.core.storage.index_store.types import (
|
||||
DEFAULT_PERSIST_FNAME as INDEXSTORE,
|
||||
)
|
||||
class Postgres:
    """Wipe/stats handler for a Postgres-backed node, index or vector store.

    NOTE(review): reconstructed from diff residue that interleaved the removed
    ``wipe_simple``/``wipe_postgres`` functions into this class body.
    """

    # Tables owned by each store type, as created by the application.
    tables: ClassVar[dict[str, list[str]]] = {
        "nodestore": ["data_docstore", "data_indexstore"],
        "vectorstore": ["data_embeddings"],
    }

    def __init__(self) -> None:
        # Imported lazily so the psycopg2 extra is only required when this
        # handler actually runs.
        try:
            import psycopg2
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Postgres dependencies not found") from None

        connection = settings().postgres.model_dump(exclude_none=True)
        # schema_name is not a psycopg2 connect() keyword; keep it for SQL.
        self.schema = connection.pop("schema_name")
        self.conn = psycopg2.connect(**connection)

    def wipe(self, storetype: str) -> None:
        """Drop every table belonging to *storetype* in the configured schema."""
        cur = self.conn.cursor()
        try:
            for table in self.tables[storetype]:
                sql = f"DROP TABLE IF EXISTS {self.schema}.{table}"
                cur.execute(sql)
                print(f"Table {self.schema}.{table} dropped.")
            self.conn.commit()
        finally:
            cur.close()

    def stats(self, store_type: str) -> None:
        """Print row count and on-disk size for each *store_type* table."""
        template = "SELECT '{table}', COUNT(*), pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"
        sql = " UNION ALL ".join(
            template.format(table=tbl) for tbl in self.tables[store_type]
        )

        cur = self.conn.cursor()
        try:
            print(f"Storage for Postgres {store_type}.")
            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
            print("-" * 45)  # Print a line separator

            cur.execute(sql)
            for row in cur.fetchall():
                formatted_row_count = f"{row[1]:,}"
                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")

            print()
        finally:
            cur.close()

    def __del__(self):
        # Best-effort close; conn may be absent if __init__ failed early.
        if hasattr(self, "conn") and self.conn:
            self.conn.close()
|
||||
|
||||
|
||||
def wipe_chroma(dbtype: str):
    """Remove the on-disk Chroma vector store directory."""
    assert dbtype == "vectorstore"
    chroma_dir = (local_data_path / "chroma_db").absolute()
    wipe_tree(str(chroma_dir))
|
||||
class Simple:
    """Handler for the file-based "simple" node store."""

    def wipe(self, store_type: str) -> None:
        """Delete the persisted docstore and index-store files."""
        assert store_type == "nodestore"
        # Imported lazily so llama-index is only needed when wiping.
        from llama_index.core.storage.docstore.types import (
            DEFAULT_PERSIST_FNAME as DOCSTORE,
        )
        from llama_index.core.storage.index_store.types import (
            DEFAULT_PERSIST_FNAME as INDEXSTORE,
        )

        for filename in (DOCSTORE, INDEXSTORE):
            target = (local_data_path / filename).absolute()
            wipe_file(str(target))
|
||||
|
||||
|
||||
class Chroma:
    """Handler for the on-disk Chroma vector store."""

    def wipe(self, store_type: str) -> None:
        """Remove the whole chroma_db directory tree."""
        assert store_type == "vectorstore"
        chroma_dir = (local_data_path / "chroma_db").absolute()
        wipe_tree(str(chroma_dir))
|
||||
|
||||
|
||||
class Qdrant:
    """Wipe/stats handler talking to a Qdrant server via qdrant_client."""

    COLLECTION = (
        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
    )

    def __init__(self) -> None:
        # Imported lazily so the qdrant extra is only needed when this runs.
        try:
            from qdrant_client import QdrantClient  # type: ignore
        except ImportError:
            raise ImportError("Qdrant dependencies not found") from None
        client_kwargs = settings().qdrant.model_dump(exclude_none=True)
        self.client = QdrantClient(**client_kwargs)

    def wipe(self, store_type: str) -> None:
        """Drop the single hard-coded collection used by the app."""
        assert store_type == "vectorstore"
        try:
            self.client.delete_collection(self.COLLECTION)
        except Exception as e:
            print("Error dropping collection:", e)
        else:
            print("Collection dropped successfully.")

    def stats(self, store_type: str) -> None:
        """Print point/vector counts for the collection, if it exists."""
        print(f"Storage for Qdrant {store_type}.")
        info = None
        try:
            info = self.client.get_collection(self.COLLECTION)
        except ValueError:
            pass
        if info:
            # Collection Info
            # https://qdrant.tech/documentation/concepts/collections/
            print(f"\tPoints: {info.points_count:,}")
            print(f"\tVectors: {info.vectors_count:,}")
            print(f"\tIndex Vectors: {info.indexed_vectors_count:,}")
            return
        print("\t- Qdrant collection not found or empty")
|
||||
|
||||
|
||||
class Command:
    """Dispatches a CLI command to the handler of each configured store."""

    # Maps the configured database name to its handler class.
    DB_HANDLERS: ClassVar[dict[str, Any]] = {
        "simple": Simple,  # node store
        "chroma": Chroma,  # vector store
        "postgres": Postgres,  # node, index and vector store
        "qdrant": Qdrant,  # vector store
    }

    def for_each_store(self, cmd: str):
        """Run *cmd* against both the nodestore and the vectorstore."""
        for store_type in ("nodestore", "vectorstore"):
            database = getattr(settings(), store_type).database
            handler_cls = self.DB_HANDLERS.get(database)
            if handler_cls is None:
                print(f"No handler found for database '{database}'")
                continue
            handler = handler_cls()  # Instantiate the class
            # If the DB can handle this cmd dispatch it.
            method = getattr(handler, cmd, None)
            if callable(method):
                method(store_type)
            else:
                print(
                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
                )

    def execute(self, cmd: str) -> None:
        """Entry point: only 'wipe' and 'stats' are recognized commands."""
        if cmd in ("wipe", "stats"):
            self.for_each_store(cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: `python scripts/utils.py {wipe,stats}`.
    # The merged diff residue registered the "mode" argument twice (argparse
    # raises ArgumentError on a duplicate dest) and dispatched through a
    # `commands` dict lacking "stats" (KeyError); the Command dispatcher is
    # the single source of truth now.
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    args = parser.parse_args()

    Command().execute(args.mode.lower())
|
||||
|
Loading…
Reference in New Issue
Block a user