fix(core): Delete corresponding Milvus data when using with TuGraph (#2858)
@@ -464,6 +464,23 @@ class KnowledgeService:
         )
         # delete vector by ids
         storage_connector.delete_by_ids(vector_ids)
+
+        # next, delete the corresponding CHUNK_HISTORY data in Milvus
+        if (
+            space.vector_type == "KnowledgeGraph"
+            and storage_connector._vector_store_config.__type__ == "milvus"
+        ):
+            # this gives the vector store type,
+            # in case Chroma is supported in the future
+            embedding_vector_type = storage_connector._vector_store_config.__type__
+            # get the collection name
+            space_name = space_name + "_CHUNK_HISTORY"
+            storage_connector = self.storage_manager.get_storage_connector(
+                index_name=space_name, storage_type=embedding_vector_type
+            )
+            # documents[0].id is the id used to find the corresponding document
+            storage_connector.delete_by_file_id(documents[0].id)
+
         # delete chunks
         document_chunk_dao.raw_delete(documents[0].id)
         # delete document
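The guard above encodes a naming convention worth spelling out: for a KnowledgeGraph space backed by Milvus, the graph extractor keeps its chunk history in a sibling collection named `<space>_CHUNK_HISTORY`, so deleting the space must also purge that collection. A minimal, self-contained sketch of the routing decision (the function and names below are illustrative, not the project's API):

```python
def chunk_history_collection(space_name: str, vector_type: str, store_type: str):
    """Return the chunk-history collection to purge, or None if there is none."""
    # Only KnowledgeGraph spaces backed by Milvus keep a chunk-history sibling.
    if vector_type == "KnowledgeGraph" and store_type == "milvus":
        return f"{space_name}_CHUNK_HISTORY"
    return None


assert chunk_history_collection("wiki", "KnowledgeGraph", "milvus") == "wiki_CHUNK_HISTORY"
assert chunk_history_collection("wiki", "KnowledgeGraph", "chroma") is None  # not yet supported
```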
@@ -158,6 +158,7 @@ class IndexStoreBase(ABC):
         chunks: List[Chunk],
         max_chunks_once_load: Optional[int] = None,
         max_threads: Optional[int] = None,
+        file_id: Optional[str] = None,
     ) -> List[str]:
         """Load document in index database with specified limit.

@@ -171,6 +172,7 @@ class IndexStoreBase(ABC):
         """
         max_chunks_once_load = max_chunks_once_load or self._max_chunks_once_load
         max_threads = max_threads or self._max_threads
+        file_id = file_id or None
         chunk_groups = [
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
@@ -181,7 +183,7 @@ class IndexStoreBase(ABC):
         )
         tasks = []
         for chunk_group in chunk_groups:
-            tasks.append(self.aload_document(chunk_group))
+            tasks.append(self.aload_document(chunk_group, file_id))

         results = await self._run_tasks_with_concurrency(tasks, max_threads)

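The batching in `aload_document_with_limit` is plain list slicing: chunks are split into groups of at most `max_chunks_once_load`, and each group becomes one `aload_document(chunk_group, file_id)` task. (Note that `file_id = file_id or None` only normalizes falsy values such as an empty string to None; for a None or non-empty id it changes nothing.) A standalone demonstration of the slicing, with illustrative values:

```python
chunks = list(range(10))  # stand-ins for Chunk objects
max_chunks_once_load = 4

chunk_groups = [
    chunks[i : i + max_chunks_once_load]
    for i in range(0, len(chunks), max_chunks_once_load)
]
print(chunk_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```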
@@ -187,7 +187,9 @@ class VectorStoreBase(IndexStoreBase, ABC):
         """Return a similarity score on a scale [0, 1]."""
         return 1.0 - distance / math.sqrt(2)

-    async def aload_document(self, chunks: List[Chunk]) -> List[str]:  # type: ignore
+    async def aload_document(
+        self, chunks: List[Chunk], file_id: Optional[str] = None
+    ) -> List[str]:  # type: ignore
         """Async load document in index database.

         Args:
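The context line `return 1.0 - distance / math.sqrt(2)` maps a distance in [0, √2] onto a [0, 1] similarity score. A quick numeric check:

```python
import math


def similarity_score(distance: float) -> float:
    # 1.0 at distance 0 (identical vectors), 0.0 at sqrt(2) (maximally distant).
    return 1.0 - distance / math.sqrt(2)


print(similarity_score(0.0))           # 1.0
print(similarity_score(math.sqrt(2)))  # 0.0
print(similarity_score(0.7071))        # ~0.5
```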
@@ -146,8 +146,9 @@ class EmbeddingAssembler(BaseAssembler):
         # persist chunks into vector store
         max_chunks_once_load = kwargs.get("max_chunks_once_load")
         max_threads = kwargs.get("max_threads")
+        file_id = kwargs.get("file_id", None)
         return await self._index_store.aload_document_with_limit(
-            self._chunks, max_chunks_once_load, max_threads
+            self._chunks, max_chunks_once_load, max_threads, file_id
        )

     def _extract_info(self, chunks) -> List[Chunk]:
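Because `file_id` travels through `**kwargs`, older call sites that never pass it keep working unchanged. A standalone illustration of that plumbing (the function is a stand-in, not the project's assembler):

```python
def persist(**kwargs):
    max_chunks_once_load = kwargs.get("max_chunks_once_load")
    max_threads = kwargs.get("max_threads")
    file_id = kwargs.get("file_id", None)  # None for callers that predate this change
    return (max_chunks_once_load, max_threads, file_id)


print(persist(max_chunks_once_load=10, max_threads=2))             # (10, 2, None)
print(persist(max_chunks_once_load=10, max_threads=2, file_id=7))  # (10, 2, 7)
```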
@@ -39,7 +39,9 @@ class GraphExtractor(LLMExtractor):
         self._topk = top_k
         self._score_threshold = score_threshold

-    async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
+    async def aload_chunk_context(
+        self, texts: List[str], file_id: Optional[str] = None
+    ) -> Dict[str, str]:
         """Load chunk context."""
         text_context_map: Dict[str, str] = {}

@@ -53,8 +55,14 @@ class GraphExtractor(LLMExtractor):
         ]

         # Save chunk to history
+        # here we save the file_id into the metadata
         await self._chunk_history.aload_document_with_limit(
-            [Chunk(content=text, metadata={"relevant_cnt": len(history)})],
+            [
+                Chunk(
+                    content=text,
+                    metadata={"relevant_cnt": len(history), "file_id": file_id},
+                )
+            ],
             self._max_chunks_once_load,
             self._max_threads,
         )
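The net effect is one extra key in each chunk-history record's metadata. The dicts below stand in for `Chunk.metadata`; the serialized form matters because `MilvusStore.delete_by_file_id` (later in this commit) finds these rows with a substring match over the metadata field, which the delete expression suggests is persisted as serialized JSON text:

```python
import json

before = {"relevant_cnt": 3}
after = {"relevant_cnt": 3, "file_id": 42}

# The delete expression later matches this serialized text with a `like` pattern.
print(json.dumps(after))  # {"relevant_cnt": 3, "file_id": 42}
```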
@@ -81,6 +89,7 @@ class GraphExtractor(LLMExtractor):
         texts: List[str],
         batch_size: int = 1,
         limit: Optional[int] = None,
+        file_id: Optional[str] = None,
     ) -> Optional[List[List[Graph]]]:
         """Extract graphs from chunks in batches.

@@ -90,7 +99,7 @@ class GraphExtractor(LLMExtractor):
             raise ValueError("batch_size >= 1")

         # 1. Load chunk context
-        text_context_map = await self.aload_chunk_context(texts)
+        text_context_map = await self.aload_chunk_context(texts, file_id)

         # Pre-allocate results list to maintain order
         graphs_list: List[List[Graph]] = [None] * len(texts)
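The pre-allocated `graphs_list` in the context above keeps results aligned with input order even when concurrent extractions finish out of order. A self-contained asyncio sketch of the same pattern (illustrative, not the project's code):

```python
import asyncio


async def work(i: int) -> str:
    await asyncio.sleep(0.01 * (5 - i))  # later inputs finish first
    return f"result-{i}"


async def main() -> None:
    results: list = [None] * 5  # pre-allocate to preserve input order

    async def run(i: int) -> None:
        results[i] = await work(i)  # each task writes into its fixed slot

    await asyncio.gather(*(run(i) for i in range(5)))
    print(results)  # ['result-0', ..., 'result-4'] regardless of completion order


asyncio.run(main())
```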
@@ -309,12 +309,14 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
         """Get the knowledge graph config."""
         return self._embedding_fn

-    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
+    async def aload_document(
+        self, chunks: List[Chunk], file_id: Optional[str] = None
+    ) -> List[str]:
         """Extract and persist graph from the document file."""
         if not self.vector_name_exists():
             self._graph_store_adapter.create_graph(self._graph_name)
         await self._aload_document_graph(chunks)
-        await self._aload_triplet_graph(chunks)
+        await self._aload_triplet_graph(chunks, file_id)
         await self._community_store.build_communities(
             batch_size=self._community_summary_batch_size
         )
@@ -364,7 +366,9 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
                 chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk
             )

-    async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
+    async def _aload_triplet_graph(
+        self, chunks: List[Chunk], file_id: Optional[str] = None
+    ) -> None:
         """Load the knowledge graph from the chunks.

         The chunks include the doc structure.
@@ -379,6 +383,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
         graphs_list = await self._graph_extractor.batch_extract(
             [chunk.content for chunk in chunks],
             batch_size=self._triplet_extraction_batch_size,
+            file_id=file_id,
         )
         if not graphs_list:
             raise ValueError("No graphs extracted from the chunks")
@@ -693,6 +693,23 @@ class MilvusStore(VectorStoreBase):
         self.col.delete(delete_expr)
         return True

+    # delete the corresponding vectors by file_id
+    def delete_by_file_id(self, file_id: str):
+        """Delete vectors by file_id."""
+        try:
+            from pymilvus import Collection
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
+        self.col = Collection(self.collection_name)
+        # milvus: delete vectors whose metadata contains this file_id
+        logger.info(f"begin delete milvus file_id: {file_id}")
+        delete_expr = f"{self.metadata_field} like '%\"file_id\": {file_id}%'"
+        self.col.delete(delete_expr)
+        return True
+
     def convert_metadata_filters(self, filters: MetadataFilters) -> str:
         """Convert filter to milvus filters.

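The delete expression is assembled as a substring match over the serialized metadata. The sketch below only builds the string, with no Milvus connection; `metadata` as the field name is an assumption. Note the pattern expects the numeric form `"file_id": 42`; if `file_id` were stored as a JSON string, the pattern would also need the surrounding quotes (`%\"file_id\": \"42\"%`):

```python
metadata_field = "metadata"  # assumed name of the JSON-serialized scalar field
file_id = 42

delete_expr = f"{metadata_field} like '%\"file_id\": {file_id}%'"
print(delete_expr)  # metadata like '%"file_id": 42%'
```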
@@ -250,6 +250,14 @@ class VectorStoreConnector:
         """
         return self.client.delete_by_ids(ids=ids)

+    def delete_by_file_id(self, file_id):
+        """Delete vectors by file_id.
+
+        Args:
+            - file_id: the document file id
+        """
+        return self.client.delete_by_file_id(file_id=file_id)
+
     def truncate(self):
         """Truncate data."""
         return self.client.truncate()
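The connector method is a thin delegation, so it works with any client that exposes `delete_by_file_id`; a store lacking the method would raise `AttributeError`, which is why the `KnowledgeService` change above guards on the store type being `milvus`. A stub makes the delegation concrete (names are illustrative):

```python
class StubClient:
    def delete_by_file_id(self, file_id):
        print(f"would delete vectors for file_id={file_id}")
        return True


class Connector:
    def __init__(self, client):
        self.client = client

    def delete_by_file_id(self, file_id):
        # Same thin delegation as VectorStoreConnector above.
        return self.client.delete_by_file_id(file_id=file_id)


assert Connector(StubClient()).delete_by_file_id(42) is True
```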
@@ -602,9 +602,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeResponse]):

         chunk_docs = assembler.get_chunks()
         doc.chunk_size = len(chunk_docs)
+        # this is the starting point where file_id is added
         vector_ids = await assembler.apersist(
             max_chunks_once_load=max_chunks_once_load,
             max_threads=max_threads,
+            file_id=doc.id,
         )
         doc.status = SyncStatus.FINISHED.name
         doc.result = "document persist into index store success"