feat:chroma store refactor (#1508)

This commit is contained in:
Aries-ckt
2024-05-11 16:31:34 +08:00
committed by GitHub
parent bc9ce3c2ae
commit d3131552d3
6 changed files with 159 additions and 40 deletions

View File

@@ -1 +1 @@
version = "0.5.5" version = "0.5.6"

View File

@@ -247,11 +247,13 @@ class KnowledgeService:
doc_ids = sync_request.doc_ids doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids: for doc_id in doc_ids:
query = KnowledgeDocumentEntity( query = KnowledgeDocumentEntity(id=doc_id)
id=doc_id, docs = knowledge_document_dao.get_documents(query)
space=space_name, if len(docs) == 0:
raise Exception(
f"there are document called, doc_id: {sync_request.doc_id}"
) )
doc = knowledge_document_dao.get_knowledge_documents(query)[0] doc = docs[0]
if ( if (
doc.status == SyncStatus.RUNNING.name doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name or doc.status == SyncStatus.FINISHED.name

View File

@@ -177,6 +177,36 @@ class VectorStoreBase(ABC):
) )
return ids return ids
def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
"""Filter chunks by score threshold.
Args:
chunks(List[Chunks]): The chunks to filter.
score_threshold(float): The score threshold.
Return:
List[Chunks]: The filtered chunks.
"""
candidates_chunks = chunks
if score_threshold is not None:
candidates_chunks = [
Chunk(
metadata=chunk.metadata,
content=chunk.content,
score=chunk.score,
chunk_id=str(id),
)
for chunk in chunks
if chunk.score >= score_threshold
]
if len(candidates_chunks) == 0:
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
return candidates_chunks
@abstractmethod @abstractmethod
def similar_search( def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None self, text: str, topk: int, filters: Optional[MetadataFilters] = None

View File

@@ -1,7 +1,7 @@
"""Chroma vector store.""" """Chroma vector store."""
import logging import logging
import os import os
from typing import List, Optional from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from chromadb import PersistentClient from chromadb import PersistentClient
from chromadb.config import Settings from chromadb.config import Settings
@@ -17,6 +17,7 @@ from .filters import FilterOperator, MetadataFilters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CHROMA_COLLECTION_NAME = "langchain"
@register_resource( @register_resource(
_("Chroma Vector Store"), _("Chroma Vector Store"),
@@ -55,9 +56,11 @@ class ChromaStore(VectorStoreBase):
"""Chroma vector store.""" """Chroma vector store."""
def __init__(self, vector_store_config: ChromaVectorConfig) -> None: def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
"""Create a ChromaStore instance.""" """Create a ChromaStore instance.
from langchain.vectorstores import Chroma
Args:
vector_store_config(ChromaVectorConfig): vector store config.
"""
chroma_vector_config = vector_store_config.to_dict(exclude_none=True) chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get( chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data") "persist_path", os.path.join(PILOT_PATH, "data")
@@ -71,31 +74,35 @@ class ChromaStore(VectorStoreBase):
persist_directory=self.persist_dir, persist_directory=self.persist_dir,
anonymized_telemetry=False, anonymized_telemetry=False,
) )
client = PersistentClient(path=self.persist_dir, settings=chroma_settings) self._chroma_client = PersistentClient(
path=self.persist_dir, settings=chroma_settings
)
collection_metadata = chroma_vector_config.get("collection_metadata") or { collection_metadata = chroma_vector_config.get("collection_metadata") or {
"hnsw:space": "cosine" "hnsw:space": "cosine"
} }
self.vector_store_client = Chroma( self._collection = self._chroma_client.get_or_create_collection(
persist_directory=self.persist_dir, name=CHROMA_COLLECTION_NAME,
embedding_function=self.embeddings, embedding_function=None,
# client_settings=chroma_settings, metadata=collection_metadata,
client=client, )
collection_metadata=collection_metadata,
) # type: ignore
def similar_search( def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:
"""Search similar documents.""" """Search similar documents."""
logger.info("ChromaStore similar search") logger.info("ChromaStore similar search")
where_filters = self.convert_metadata_filters(filters) if filters else None chroma_results = self._query(
lc_documents = self.vector_store_client.similarity_search( text=text,
text, topk, filter=where_filters topk=topk,
filters=filters,
) )
return [ return [
Chunk(content=doc.page_content, metadata=doc.metadata) Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
for doc in lc_documents for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
)
] ]
def similar_search_with_scores( def similar_search_with_scores(
@@ -114,19 +121,26 @@ class ChromaStore(VectorStoreBase):
filters(MetadataFilters): metadata filters, defaults to None filters(MetadataFilters): metadata filters, defaults to None
""" """
logger.info("ChromaStore similar search with scores") logger.info("ChromaStore similar search with scores")
where_filters = self.convert_metadata_filters(filters) if filters else None chroma_results = self._query(
docs_and_scores = ( text=text,
self.vector_store_client.similarity_search_with_relevance_scores( topk=topk,
query=text, filters=filters,
k=topk, )
score_threshold=score_threshold, chunks = [
filter=where_filters, (
Chunk(
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=chroma_result[2],
) )
) )
return [ for chroma_result in zip(
Chunk(content=doc.page_content, metadata=doc.metadata, score=score) chroma_results["documents"][0],
for doc, score in docs_and_scores chroma_results["metadatas"][0],
chroma_results["distances"][0],
)
] ]
return self.filter_by_score_threshold(chunks, score_threshold)
def vector_name_exists(self) -> bool: def vector_name_exists(self) -> bool:
"""Whether vector name exists.""" """Whether vector name exists."""
@@ -138,19 +152,24 @@ class ChromaStore(VectorStoreBase):
files = list(filter(lambda f: f != "chroma.sqlite3", files)) files = list(filter(lambda f: f != "chroma.sqlite3", files))
return len(files) > 0 return len(files) > 0
def load_document(self, chunks: List[Chunk]) -> List[str]: def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to vector store.""" """Load document to vector store."""
logger.info("ChromaStore load document") logger.info("ChromaStore load document")
texts = [chunk.content for chunk in chunks] texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks] metadatas = [chunk.metadata for chunk in chunks]
ids = [chunk.chunk_id for chunk in chunks] ids = [chunk.chunk_id for chunk in chunks]
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids) chroma_metadatas = [
_transform_chroma_metadata(metadata) for metadata in metadatas
]
self._add_texts(texts=texts, metadatas=chroma_metadatas, ids=ids)
return ids return ids
def delete_vector_name(self, vector_name: str): def delete_vector_name(self, vector_name: str):
"""Delete vector name.""" """Delete vector name."""
logger.info(f"chroma vector_name:{vector_name} begin delete...") logger.info(f"chroma vector_name:{vector_name} begin delete...")
self.vector_store_client.delete_collection() # self.vector_store_client.delete_collection()
self._chroma_client.delete_collection(self._collection.name)
self._clean_persist_folder() self._clean_persist_folder()
return True return True
@@ -159,8 +178,7 @@ class ChromaStore(VectorStoreBase):
logger.info(f"begin delete chroma ids: {ids}") logger.info(f"begin delete chroma ids: {ids}")
ids = ids.split(",") ids = ids.split(",")
if len(ids) > 0: if len(ids) > 0:
collection = self.vector_store_client._collection self._collection.delete(ids=ids)
collection.delete(ids=ids)
def convert_metadata_filters( def convert_metadata_filters(
self, self,
@@ -198,6 +216,65 @@ class ChromaStore(VectorStoreBase):
where_filters[chroma_condition] = filters_list where_filters[chroma_condition] = filters_list
return where_filters return where_filters
def _add_texts(
self,
texts: Iterable[str],
ids: List[str],
metadatas: Optional[List[Mapping[str, Union[str, int, float, bool]]]] = None,
) -> List[str]:
"""Add texts to Chroma collection.
Args:
texts(Iterable[str]): texts.
metadatas(Optional[List[dict]]): metadatas.
ids(Optional[List[str]]): ids.
Returns:
List[str]: ids.
"""
embeddings = None
texts = list(texts)
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(texts)
if metadatas:
try:
self._collection.upsert(
metadatas=metadatas,
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
except ValueError as e:
logger.error(f"Error upsert chromadb with metadata: {e}")
else:
self._collection.upsert(
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
return ids
def _query(self, text: str, topk: int, filters: Optional[MetadataFilters] = None):
"""Query Chroma collection.
Args:
text(str): query text.
topk(int): topk.
filters(MetadataFilters): metadata filters.
Returns:
dict: query result.
"""
if not text:
return {}
where_filters = self.convert_metadata_filters(filters) if filters else None
if self.embeddings is None:
raise ValueError("Chroma Embeddings is None")
query_embedding = self.embeddings.embed_query(text)
return self._collection.query(
query_embeddings=query_embedding,
n_results=topk,
where=where_filters,
)
def _clean_persist_folder(self): def _clean_persist_folder(self):
"""Clean persist folder.""" """Clean persist folder."""
for root, dirs, files in os.walk(self.persist_dir, topdown=False): for root, dirs, files in os.walk(self.persist_dir, topdown=False):
@@ -230,3 +307,14 @@ def _convert_chroma_filter_operator(operator: str) -> str:
return "$lte" return "$lte"
else: else:
raise ValueError(f"Chroma Where operator {operator} not supported") raise ValueError(f"Chroma Where operator {operator} not supported")
def _transform_chroma_metadata(
metadata: Dict[str, Any]
) -> Mapping[str, str | int | float | bool]:
"""Transform metadata to Chroma metadata."""
transformed = {}
for key, value in metadata.items():
if isinstance(value, (str, int, float, bool)):
transformed[key] = value
return transformed

View File

@@ -66,7 +66,7 @@ class PGVectorStore(VectorStoreBase):
embedding_function=self.embeddings, embedding_function=self.embeddings,
collection_name=self.collection_name, collection_name=self.collection_name,
connection_string=self.connection_string, connection_string=self.connection_string,
) ) # mypy: ignore
def similar_search( def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None self, text: str, topk: int, filters: Optional[MetadataFilters] = None

View File

@@ -19,7 +19,7 @@ with open("README.md", mode="r", encoding="utf-8") as fh:
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true" IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files: # If you modify the version, please modify the version in the following files:
# dbgpt/_version.py # dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.5") DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.6")
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true" BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = ( LLAMA_CPP_GPU_ACCELERATION = (
@@ -499,7 +499,6 @@ def knowledge_requires():
pip install "dbgpt[rag]" pip install "dbgpt[rag]"
""" """
setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [ setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
"langchain>=0.0.286",
"spacy>=3.7", "spacy>=3.7",
"markdown", "markdown",
"bs4", "bs4",