mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat:chroma store refactor (#1508)
This commit is contained in:
@@ -1 +1 @@
|
|||||||
version = "0.5.5"
|
version = "0.5.6"
|
||||||
|
@@ -247,11 +247,13 @@ class KnowledgeService:
|
|||||||
doc_ids = sync_request.doc_ids
|
doc_ids = sync_request.doc_ids
|
||||||
self.model_name = sync_request.model_name or CFG.LLM_MODEL
|
self.model_name = sync_request.model_name or CFG.LLM_MODEL
|
||||||
for doc_id in doc_ids:
|
for doc_id in doc_ids:
|
||||||
query = KnowledgeDocumentEntity(
|
query = KnowledgeDocumentEntity(id=doc_id)
|
||||||
id=doc_id,
|
docs = knowledge_document_dao.get_documents(query)
|
||||||
space=space_name,
|
if len(docs) == 0:
|
||||||
|
raise Exception(
|
||||||
|
f"there are document called, doc_id: {sync_request.doc_id}"
|
||||||
)
|
)
|
||||||
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
|
doc = docs[0]
|
||||||
if (
|
if (
|
||||||
doc.status == SyncStatus.RUNNING.name
|
doc.status == SyncStatus.RUNNING.name
|
||||||
or doc.status == SyncStatus.FINISHED.name
|
or doc.status == SyncStatus.FINISHED.name
|
||||||
|
@@ -177,6 +177,36 @@ class VectorStoreBase(ABC):
|
|||||||
)
|
)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
|
def filter_by_score_threshold(
|
||||||
|
self, chunks: List[Chunk], score_threshold: float
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Filter chunks by score threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks(List[Chunks]): The chunks to filter.
|
||||||
|
score_threshold(float): The score threshold.
|
||||||
|
Return:
|
||||||
|
List[Chunks]: The filtered chunks.
|
||||||
|
"""
|
||||||
|
candidates_chunks = chunks
|
||||||
|
if score_threshold is not None:
|
||||||
|
candidates_chunks = [
|
||||||
|
Chunk(
|
||||||
|
metadata=chunk.metadata,
|
||||||
|
content=chunk.content,
|
||||||
|
score=chunk.score,
|
||||||
|
chunk_id=str(id),
|
||||||
|
)
|
||||||
|
for chunk in chunks
|
||||||
|
if chunk.score >= score_threshold
|
||||||
|
]
|
||||||
|
if len(candidates_chunks) == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No relevant docs were retrieved using the relevance score"
|
||||||
|
f" threshold {score_threshold}"
|
||||||
|
)
|
||||||
|
return candidates_chunks
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similar_search(
|
def similar_search(
|
||||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Chroma vector store."""
|
"""Chroma vector store."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from chromadb import PersistentClient
|
from chromadb import PersistentClient
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
@@ -17,6 +17,7 @@ from .filters import FilterOperator, MetadataFilters
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CHROMA_COLLECTION_NAME = "langchain"
|
||||||
|
|
||||||
@register_resource(
|
@register_resource(
|
||||||
_("Chroma Vector Store"),
|
_("Chroma Vector Store"),
|
||||||
@@ -55,9 +56,11 @@ class ChromaStore(VectorStoreBase):
|
|||||||
"""Chroma vector store."""
|
"""Chroma vector store."""
|
||||||
|
|
||||||
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
|
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
|
||||||
"""Create a ChromaStore instance."""
|
"""Create a ChromaStore instance.
|
||||||
from langchain.vectorstores import Chroma
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_store_config(ChromaVectorConfig): vector store config.
|
||||||
|
"""
|
||||||
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
||||||
chroma_path = chroma_vector_config.get(
|
chroma_path = chroma_vector_config.get(
|
||||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||||
@@ -71,31 +74,35 @@ class ChromaStore(VectorStoreBase):
|
|||||||
persist_directory=self.persist_dir,
|
persist_directory=self.persist_dir,
|
||||||
anonymized_telemetry=False,
|
anonymized_telemetry=False,
|
||||||
)
|
)
|
||||||
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
|
self._chroma_client = PersistentClient(
|
||||||
|
path=self.persist_dir, settings=chroma_settings
|
||||||
|
)
|
||||||
|
|
||||||
collection_metadata = chroma_vector_config.get("collection_metadata") or {
|
collection_metadata = chroma_vector_config.get("collection_metadata") or {
|
||||||
"hnsw:space": "cosine"
|
"hnsw:space": "cosine"
|
||||||
}
|
}
|
||||||
self.vector_store_client = Chroma(
|
self._collection = self._chroma_client.get_or_create_collection(
|
||||||
persist_directory=self.persist_dir,
|
name=CHROMA_COLLECTION_NAME,
|
||||||
embedding_function=self.embeddings,
|
embedding_function=None,
|
||||||
# client_settings=chroma_settings,
|
metadata=collection_metadata,
|
||||||
client=client,
|
)
|
||||||
collection_metadata=collection_metadata,
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
def similar_search(
|
def similar_search(
|
||||||
self, text, topk, filters: Optional[MetadataFilters] = None
|
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Search similar documents."""
|
"""Search similar documents."""
|
||||||
logger.info("ChromaStore similar search")
|
logger.info("ChromaStore similar search")
|
||||||
where_filters = self.convert_metadata_filters(filters) if filters else None
|
chroma_results = self._query(
|
||||||
lc_documents = self.vector_store_client.similarity_search(
|
text=text,
|
||||||
text, topk, filter=where_filters
|
topk=topk,
|
||||||
|
filters=filters,
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
Chunk(content=doc.page_content, metadata=doc.metadata)
|
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
|
||||||
for doc in lc_documents
|
for chroma_result in zip(
|
||||||
|
chroma_results["documents"][0],
|
||||||
|
chroma_results["metadatas"][0],
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def similar_search_with_scores(
|
def similar_search_with_scores(
|
||||||
@@ -114,19 +121,26 @@ class ChromaStore(VectorStoreBase):
|
|||||||
filters(MetadataFilters): metadata filters, defaults to None
|
filters(MetadataFilters): metadata filters, defaults to None
|
||||||
"""
|
"""
|
||||||
logger.info("ChromaStore similar search with scores")
|
logger.info("ChromaStore similar search with scores")
|
||||||
where_filters = self.convert_metadata_filters(filters) if filters else None
|
chroma_results = self._query(
|
||||||
docs_and_scores = (
|
text=text,
|
||||||
self.vector_store_client.similarity_search_with_relevance_scores(
|
topk=topk,
|
||||||
query=text,
|
filters=filters,
|
||||||
k=topk,
|
)
|
||||||
score_threshold=score_threshold,
|
chunks = [
|
||||||
filter=where_filters,
|
(
|
||||||
|
Chunk(
|
||||||
|
content=chroma_result[0],
|
||||||
|
metadata=chroma_result[1] or {},
|
||||||
|
score=chroma_result[2],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return [
|
for chroma_result in zip(
|
||||||
Chunk(content=doc.page_content, metadata=doc.metadata, score=score)
|
chroma_results["documents"][0],
|
||||||
for doc, score in docs_and_scores
|
chroma_results["metadatas"][0],
|
||||||
|
chroma_results["distances"][0],
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
return self.filter_by_score_threshold(chunks, score_threshold)
|
||||||
|
|
||||||
def vector_name_exists(self) -> bool:
|
def vector_name_exists(self) -> bool:
|
||||||
"""Whether vector name exists."""
|
"""Whether vector name exists."""
|
||||||
@@ -138,19 +152,24 @@ class ChromaStore(VectorStoreBase):
|
|||||||
files = list(filter(lambda f: f != "chroma.sqlite3", files))
|
files = list(filter(lambda f: f != "chroma.sqlite3", files))
|
||||||
return len(files) > 0
|
return len(files) > 0
|
||||||
|
|
||||||
|
|
||||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||||
"""Load document to vector store."""
|
"""Load document to vector store."""
|
||||||
logger.info("ChromaStore load document")
|
logger.info("ChromaStore load document")
|
||||||
texts = [chunk.content for chunk in chunks]
|
texts = [chunk.content for chunk in chunks]
|
||||||
metadatas = [chunk.metadata for chunk in chunks]
|
metadatas = [chunk.metadata for chunk in chunks]
|
||||||
ids = [chunk.chunk_id for chunk in chunks]
|
ids = [chunk.chunk_id for chunk in chunks]
|
||||||
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
chroma_metadatas = [
|
||||||
|
_transform_chroma_metadata(metadata) for metadata in metadatas
|
||||||
|
]
|
||||||
|
self._add_texts(texts=texts, metadatas=chroma_metadatas, ids=ids)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def delete_vector_name(self, vector_name: str):
|
def delete_vector_name(self, vector_name: str):
|
||||||
"""Delete vector name."""
|
"""Delete vector name."""
|
||||||
logger.info(f"chroma vector_name:{vector_name} begin delete...")
|
logger.info(f"chroma vector_name:{vector_name} begin delete...")
|
||||||
self.vector_store_client.delete_collection()
|
# self.vector_store_client.delete_collection()
|
||||||
|
self._chroma_client.delete_collection(self._collection.name)
|
||||||
self._clean_persist_folder()
|
self._clean_persist_folder()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -159,8 +178,7 @@ class ChromaStore(VectorStoreBase):
|
|||||||
logger.info(f"begin delete chroma ids: {ids}")
|
logger.info(f"begin delete chroma ids: {ids}")
|
||||||
ids = ids.split(",")
|
ids = ids.split(",")
|
||||||
if len(ids) > 0:
|
if len(ids) > 0:
|
||||||
collection = self.vector_store_client._collection
|
self._collection.delete(ids=ids)
|
||||||
collection.delete(ids=ids)
|
|
||||||
|
|
||||||
def convert_metadata_filters(
|
def convert_metadata_filters(
|
||||||
self,
|
self,
|
||||||
@@ -198,6 +216,65 @@ class ChromaStore(VectorStoreBase):
|
|||||||
where_filters[chroma_condition] = filters_list
|
where_filters[chroma_condition] = filters_list
|
||||||
return where_filters
|
return where_filters
|
||||||
|
|
||||||
|
def _add_texts(
|
||||||
|
self,
|
||||||
|
texts: Iterable[str],
|
||||||
|
ids: List[str],
|
||||||
|
metadatas: Optional[List[Mapping[str, Union[str, int, float, bool]]]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Add texts to Chroma collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts(Iterable[str]): texts.
|
||||||
|
metadatas(Optional[List[dict]]): metadatas.
|
||||||
|
ids(Optional[List[str]]): ids.
|
||||||
|
Returns:
|
||||||
|
List[str]: ids.
|
||||||
|
"""
|
||||||
|
embeddings = None
|
||||||
|
texts = list(texts)
|
||||||
|
if self.embeddings is not None:
|
||||||
|
embeddings = self.embeddings.embed_documents(texts)
|
||||||
|
if metadatas:
|
||||||
|
try:
|
||||||
|
self._collection.upsert(
|
||||||
|
metadatas=metadatas,
|
||||||
|
embeddings=embeddings, # type: ignore
|
||||||
|
documents=texts,
|
||||||
|
ids=ids,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Error upsert chromadb with metadata: {e}")
|
||||||
|
else:
|
||||||
|
self._collection.upsert(
|
||||||
|
embeddings=embeddings, # type: ignore
|
||||||
|
documents=texts,
|
||||||
|
ids=ids,
|
||||||
|
)
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def _query(self, text: str, topk: int, filters: Optional[MetadataFilters] = None):
|
||||||
|
"""Query Chroma collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text(str): query text.
|
||||||
|
topk(int): topk.
|
||||||
|
filters(MetadataFilters): metadata filters.
|
||||||
|
Returns:
|
||||||
|
dict: query result.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return {}
|
||||||
|
where_filters = self.convert_metadata_filters(filters) if filters else None
|
||||||
|
if self.embeddings is None:
|
||||||
|
raise ValueError("Chroma Embeddings is None")
|
||||||
|
query_embedding = self.embeddings.embed_query(text)
|
||||||
|
return self._collection.query(
|
||||||
|
query_embeddings=query_embedding,
|
||||||
|
n_results=topk,
|
||||||
|
where=where_filters,
|
||||||
|
)
|
||||||
|
|
||||||
def _clean_persist_folder(self):
|
def _clean_persist_folder(self):
|
||||||
"""Clean persist folder."""
|
"""Clean persist folder."""
|
||||||
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
|
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
|
||||||
@@ -230,3 +307,14 @@ def _convert_chroma_filter_operator(operator: str) -> str:
|
|||||||
return "$lte"
|
return "$lte"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Chroma Where operator {operator} not supported")
|
raise ValueError(f"Chroma Where operator {operator} not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_chroma_metadata(
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
) -> Mapping[str, str | int | float | bool]:
|
||||||
|
"""Transform metadata to Chroma metadata."""
|
||||||
|
transformed = {}
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(value, (str, int, float, bool)):
|
||||||
|
transformed[key] = value
|
||||||
|
return transformed
|
||||||
|
@@ -66,7 +66,7 @@ class PGVectorStore(VectorStoreBase):
|
|||||||
embedding_function=self.embeddings,
|
embedding_function=self.embeddings,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
connection_string=self.connection_string,
|
connection_string=self.connection_string,
|
||||||
)
|
) # mypy: ignore
|
||||||
|
|
||||||
def similar_search(
|
def similar_search(
|
||||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||||
|
3
setup.py
3
setup.py
@@ -19,7 +19,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
|
|||||||
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
|
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
|
||||||
# If you modify the version, please modify the version in the following files:
|
# If you modify the version, please modify the version in the following files:
|
||||||
# dbgpt/_version.py
|
# dbgpt/_version.py
|
||||||
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.5")
|
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.6")
|
||||||
|
|
||||||
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
|
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
|
||||||
LLAMA_CPP_GPU_ACCELERATION = (
|
LLAMA_CPP_GPU_ACCELERATION = (
|
||||||
@@ -499,7 +499,6 @@ def knowledge_requires():
|
|||||||
pip install "dbgpt[rag]"
|
pip install "dbgpt[rag]"
|
||||||
"""
|
"""
|
||||||
setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
|
setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
|
||||||
"langchain>=0.0.286",
|
|
||||||
"spacy>=3.7",
|
"spacy>=3.7",
|
||||||
"markdown",
|
"markdown",
|
||||||
"bs4",
|
"bs4",
|
||||||
|
Reference in New Issue
Block a user