Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-25 04:53:36 +00:00)

commit 83d7e9d82d (parent 7f55aa4b6e)
fix(ChatKnowledge): add aload_document (#1548)

Summary: the knowledge-document sync path becomes fully asynchronous. The sync endpoint now delegates to the serve-layer Service, EmbeddingAssembler gains an async aload_from_knowledge classmethod, IndexStoreBase and VectorStoreBase gain executor-backed async loaders (aload_document, aload_document_with_limit), and every vector store calls super().__init__() so the shared executor is initialized.
@@ -27,6 +27,7 @@ from dbgpt.configs.model_config import (
     EMBEDDING_MODEL_CONFIG,
     KNOWLEDGE_UPLOAD_ROOT_PATH,
 )
+from dbgpt.rag import ChunkParameters
 from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.knowledge.base import ChunkStrategy
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -235,13 +236,30 @@ async def document_upload(


 @router.post("/knowledge/{space_name}/document/sync")
-def document_sync(space_name: str, request: DocumentSyncRequest):
+async def document_sync(
+    space_name: str,
+    request: DocumentSyncRequest,
+    service: Service = Depends(get_rag_service),
+):
     logger.info(f"Received params: {space_name}, {request}")
     try:
-        knowledge_space_service.sync_knowledge_document(
-            space_name=space_name, sync_request=request
-        )
-        return Result.succ([])
+        space = service.get({"name": space_name})
+        if space is None:
+            return Result.failed(code="E000X", msg=f"space {space_name} not exist")
+        if request.doc_ids is None or len(request.doc_ids) == 0:
+            return Result.failed(code="E000X", msg="doc_ids is None")
+        sync_request = KnowledgeSyncRequest(
+            doc_id=request.doc_ids[0],
+            space_id=str(space.id),
+            model_name=request.model_name,
+        )
+        sync_request.chunk_parameters = ChunkParameters(
+            chunk_strategy="Automatic",
+            chunk_size=request.chunk_size or 512,
+            chunk_overlap=request.chunk_overlap or 50,
+        )
+        doc_ids = await service.sync_document(requests=[sync_request])
+        return Result.succ(doc_ids)
     except Exception as e:
         return Result.failed(code="E000X", msg=f"document sync error {e}")
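For illustration only (not part of this commit): a minimal client call against the rewritten endpoint. The host, port, and route prefix are assumptions; the payload fields mirror DocumentSyncRequest as used above.

```python
# Hypothetical client sketch: trigger an async document sync.
# Assumes the knowledge router is reachable at localhost:5670 with no extra
# prefix; adjust the URL to your deployment.
import requests

resp = requests.post(
    "http://localhost:5670/knowledge/my_space/document/sync",
    json={
        "doc_ids": [42],      # only the first id is synced by this endpoint
        "chunk_size": 512,    # falls back to 512 server-side when omitted
        "chunk_overlap": 50,  # falls back to 50 server-side when omitted
    },
)
print(resp.json())  # Result envelope; data holds the synced doc ids
```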
@@ -1,7 +1,6 @@
 import json
 import logging
 from datetime import datetime
-from typing import List

 from dbgpt._private.config import Config
 from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
@@ -32,13 +31,8 @@ from dbgpt.rag.assembler.embedding import EmbeddingAssembler
 from dbgpt.rag.assembler.summary import SummaryAssembler
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
+from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
-from dbgpt.rag.text_splitter.text_splitter import (
-    RecursiveCharacterTextSplitter,
-    SpacyTextSplitter,
-)
-from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
 from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
 from dbgpt.serve.rag.service.service import SyncStatus
 from dbgpt.storage.vector_store.base import VectorStoreConfig
@@ -199,186 +193,6 @@ class KnowledgeService:
         total = knowledge_document_dao.get_knowledge_documents_count(query)
         return DocumentQueryResponse(data=data, total=total, page=page)

-    def batch_document_sync(
-        self,
-        space_name,
-        sync_requests: List[KnowledgeSyncRequest],
-    ) -> List[int]:
-        """batch sync knowledge document chunk into vector store
-        Args:
-            - space: Knowledge Space Name
-            - sync_requests: List[KnowledgeSyncRequest]
-        Returns:
-            - List[int]: document ids
-        """
-        doc_ids = []
-        for sync_request in sync_requests:
-            docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
-            if len(docs) == 0:
-                raise Exception(
-                    f"there are document called, doc_id: {sync_request.doc_id}"
-                )
-            doc = docs[0]
-            if (
-                doc.status == SyncStatus.RUNNING.name
-                or doc.status == SyncStatus.FINISHED.name
-            ):
-                raise Exception(
-                    f" doc:{doc.doc_name} status is {doc.status}, can not sync"
-                )
-            chunk_parameters = sync_request.chunk_parameters
-            if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
-                space_context = self.get_space_context(space_name)
-                chunk_parameters.chunk_size = (
-                    CFG.KNOWLEDGE_CHUNK_SIZE
-                    if space_context is None
-                    else int(space_context["embedding"]["chunk_size"])
-                )
-                chunk_parameters.chunk_overlap = (
-                    CFG.KNOWLEDGE_CHUNK_OVERLAP
-                    if space_context is None
-                    else int(space_context["embedding"]["chunk_overlap"])
-                )
-            self._sync_knowledge_document(space_name, doc, chunk_parameters)
-            doc_ids.append(doc.id)
-        return doc_ids
-
-    def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
-        """sync knowledge document chunk into vector store
-        Args:
-            - space: Knowledge Space Name
-            - sync_request: DocumentSyncRequest
-        """
-        from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter
-
-        doc_ids = sync_request.doc_ids
-        self.model_name = sync_request.model_name or CFG.LLM_MODEL
-        for doc_id in doc_ids:
-            query = KnowledgeDocumentEntity(id=doc_id)
-            docs = knowledge_document_dao.get_documents(query)
-            if len(docs) == 0:
-                raise Exception(
-                    f"there are document called, doc_id: {sync_request.doc_id}"
-                )
-            doc = docs[0]
-            if (
-                doc.status == SyncStatus.RUNNING.name
-                or doc.status == SyncStatus.FINISHED.name
-            ):
-                raise Exception(
-                    f" doc:{doc.doc_name} status is {doc.status}, can not sync"
-                )
-
-            space_context = self.get_space_context(space_name)
-            chunk_size = (
-                CFG.KNOWLEDGE_CHUNK_SIZE
-                if space_context is None
-                else int(space_context["embedding"]["chunk_size"])
-            )
-            chunk_overlap = (
-                CFG.KNOWLEDGE_CHUNK_OVERLAP
-                if space_context is None
-                else int(space_context["embedding"]["chunk_overlap"])
-            )
-            if sync_request.chunk_size:
-                chunk_size = sync_request.chunk_size
-            if sync_request.chunk_overlap:
-                chunk_overlap = sync_request.chunk_overlap
-            separators = sync_request.separators or None
-            from dbgpt.rag.chunk_manager import ChunkParameters
-
-            chunk_parameters = ChunkParameters(
-                chunk_size=chunk_size, chunk_overlap=chunk_overlap
-            )
-            if CFG.LANGUAGE == "en":
-                text_splitter = RecursiveCharacterTextSplitter(
-                    separators=separators,
-                    chunk_size=chunk_size,
-                    chunk_overlap=chunk_overlap,
-                    length_function=len,
-                )
-            else:
-                if separators and len(separators) > 1:
-                    raise ValueError(
-                        "SpacyTextSplitter do not support multipsle separators"
-                    )
-                try:
-                    separator = "\n\n" if not separators else separators[0]
-                    text_splitter = SpacyTextSplitter(
-                        separator=separator,
-                        pipeline="zh_core_web_sm",
-                        chunk_size=chunk_size,
-                        chunk_overlap=chunk_overlap,
-                    )
-                except Exception:
-                    text_splitter = RecursiveCharacterTextSplitter(
-                        separators=separators,
-                        chunk_size=chunk_size,
-                        chunk_overlap=chunk_overlap,
-                    )
-            if sync_request.pre_separator:
-                logger.info(f"Use preseparator, {sync_request.pre_separator}")
-                text_splitter = PreTextSplitter(
-                    pre_separator=sync_request.pre_separator,
-                    text_splitter_impl=text_splitter,
-                )
-            chunk_parameters.text_splitter = text_splitter
-            self._sync_knowledge_document(space_name, doc, chunk_parameters)
-        return doc.id
-
-    def _sync_knowledge_document(
-        self,
-        space_name,
-        doc: KnowledgeDocumentEntity,
-        chunk_parameters: ChunkParameters,
-    ) -> List[Chunk]:
-        """sync knowledge document chunk into vector store"""
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
-        embedding_fn = embedding_factory.create(
-            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-        )
-
-        spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
-        if len(spaces) != 1:
-            raise Exception(f"invalid space name:{space_name}")
-        space = spaces[0]
-
-        from dbgpt.storage.vector_store.base import VectorStoreConfig
-
-        config = VectorStoreConfig(
-            name=space.name,
-            embedding_fn=embedding_fn,
-            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
-            llm_client=self.llm_client,
-            model_name=self.model_name,
-        )
-        vector_store_connector = VectorStoreConnector(
-            vector_store_type=space.vector_type, vector_store_config=config
-        )
-        knowledge = KnowledgeFactory.create(
-            datasource=doc.content,
-            knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
-        )
-        assembler = EmbeddingAssembler.load_from_knowledge(
-            knowledge=knowledge,
-            chunk_parameters=chunk_parameters,
-            embeddings=embedding_fn,
-            vector_store_connector=vector_store_connector,
-        )
-        chunk_docs = assembler.get_chunks()
-        doc.status = SyncStatus.RUNNING.name
-        doc.chunk_size = len(chunk_docs)
-        doc.gmt_modified = datetime.now()
-        knowledge_document_dao.update_knowledge_document(doc)
-        executor = CFG.SYSTEM_APP.get_component(
-            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
-        ).create()
-        executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
-        logger.info(f"begin save document chunks, doc:{doc.doc_name}")
-        return chunk_docs
-
     async def document_summary(self, request: DocumentSummaryRequest):
         """get document summary
         Args:
@@ -1,9 +1,11 @@
 """Embedding Assembler."""
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List, Optional

 from dbgpt.core import Chunk, Embeddings
 from dbgpt.storage.vector_store.connector import VectorStoreConnector

+from ...util.executor_utils import blocking_func_to_async
 from ..assembler.base import BaseAssembler
 from ..chunk_manager import ChunkParameters
 from ..embedding.embedding_factory import DefaultEmbeddingFactory
@@ -98,6 +100,41 @@ class EmbeddingAssembler(BaseAssembler):
             embeddings=embeddings,
         )

+    @classmethod
+    async def aload_from_knowledge(
+        cls,
+        knowledge: Knowledge,
+        vector_store_connector: VectorStoreConnector,
+        chunk_parameters: Optional[ChunkParameters] = None,
+        embedding_model: Optional[str] = None,
+        embeddings: Optional[Embeddings] = None,
+        executor: Optional[ThreadPoolExecutor] = None,
+    ) -> "EmbeddingAssembler":
+        """Load document embeddings into the vector store asynchronously.
+
+        Args:
+            knowledge: (Knowledge) Knowledge datasource.
+            vector_store_connector: (VectorStoreConnector) VectorStoreConnector to use.
+            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
+                chunking.
+            embedding_model: (Optional[str]) Embedding model to use.
+            embeddings: (Optional[Embeddings]) Embeddings to use.
+            executor: (Optional[ThreadPoolExecutor]) ThreadPoolExecutor to use.
+
+        Returns:
+            EmbeddingAssembler
+        """
+        executor = executor or ThreadPoolExecutor()
+        return await blocking_func_to_async(
+            executor,
+            cls,
+            knowledge,
+            vector_store_connector,
+            chunk_parameters,
+            embedding_model,
+            embeddings,
+        )
+
     def persist(self) -> List[str]:
         """Persist chunks into vector store.
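A usage sketch for the new classmethod (illustrative; KnowledgeFactory.from_file_path and the connector setup are assumptions, while aload_from_knowledge and apersist come from this commit):

```python
# Hypothetical sketch: build the assembler off the event loop, then persist.
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.knowledge.factory import KnowledgeFactory


async def embed_file(file_path: str, vector_store_connector) -> list:
    knowledge = KnowledgeFactory.from_file_path(file_path)  # assumed helper
    assembler = await EmbeddingAssembler.aload_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(chunk_strategy="Automatic"),
        vector_store_connector=vector_store_connector,
    )
    # apersist() pushes the chunks through the store's async load path
    return await assembler.apersist()
```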
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.storage.vector_store.filters import MetadataFilters
+from dbgpt.util.executor_utils import blocking_func_to_async

 logger = logging.getLogger(__name__)

@@ -46,6 +47,10 @@ class IndexStoreConfig(BaseModel):
 class IndexStoreBase(ABC):
     """Index store base class."""

+    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
+        """Init index store."""
+        self._executor = executor or ThreadPoolExecutor()
+
     @abstractmethod
     def load_document(self, chunks: List[Chunk]) -> List[str]:
         """Load document in index database.
@@ -143,6 +148,27 @@ class IndexStoreBase(ABC):
         )
         return ids

+    async def aload_document_with_limit(
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
+    ) -> List[str]:
+        """Load document in index database with specified limit.
+
+        Args:
+            chunks(List[Chunk]): Document chunks.
+            max_chunks_once_load(int): Max number of chunks to load at once.
+            max_threads(int): Max number of threads to use.
+
+        Return:
+            List[str]: Chunk ids.
+        """
+        return await blocking_func_to_async(
+            self._executor,
+            self.load_document_with_limit,
+            chunks,
+            max_chunks_once_load,
+            max_threads,
+        )
+
     def similar_search(
         self, text: str, topk: int, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
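blocking_func_to_async is the bridge this commit leans on: it pushes a blocking call onto an executor so the event loop is not stalled. A minimal stand-in (plain asyncio, not DB-GPT's actual implementation in dbgpt.util.executor_utils) behaves like this:

```python
# Minimal sketch of the blocking-to-async bridge; illustration only.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable


async def blocking_func_to_async(
    executor: ThreadPoolExecutor, func: Callable[..., Any], *args: Any
) -> Any:
    loop = asyncio.get_running_loop()
    # run_in_executor schedules func on the pool and suspends this coroutine
    # until the worker thread finishes, keeping the loop free meanwhile.
    return await loop.run_in_executor(executor, partial(func, *args))
```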
@@ -443,7 +443,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
         space_id,
         doc_vo: DocumentVO,
         chunk_parameters: ChunkParameters,
-    ) -> List[Chunk]:
+    ) -> None:
         """sync knowledge document chunk into vector store"""
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
@@ -470,47 +470,45 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
             datasource=doc.content,
             knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
         )
-        assembler = EmbeddingAssembler.load_from_knowledge(
-            knowledge=knowledge,
-            chunk_parameters=chunk_parameters,
-            vector_store_connector=vector_store_connector,
-        )
-        chunk_docs = assembler.get_chunks()
         doc.status = SyncStatus.RUNNING.name
-        doc.chunk_size = len(chunk_docs)
         doc.gmt_modified = datetime.now()
         self._document_dao.update_knowledge_document(doc)
-        # executor = CFG.SYSTEM_APP.get_component(
-        #     ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
-        # ).create()
-        # executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc)
-        asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc))
+        # asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc))
+        asyncio.create_task(
+            self.async_doc_embedding(
+                knowledge, chunk_parameters, vector_store_connector, doc
+            )
+        )
         logger.info(f"begin save document chunks, doc:{doc.doc_name}")
-        return chunk_docs
+        # return chunk_docs

     @trace("async_doc_embedding")
-    async def async_doc_embedding(self, assembler, chunk_docs, doc):
+    async def async_doc_embedding(
+        self, knowledge, chunk_parameters, vector_store_connector, doc
+    ):
         """async document embedding into vector db
         Args:
-            - client: EmbeddingEngine Client
-            - chunk_docs: List[Document]
-            - doc: KnowledgeDocumentEntity
+            - knowledge: Knowledge
+            - chunk_parameters: ChunkParameters
+            - vector_store_connector: vector_store_connector
+            - doc: doc
         """
-        logger.info(
-            f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}"
-        )
+        logger.info(f"async doc embedding sync, doc:{doc.doc_name}")
         try:
             with root_tracer.start_span(
                 "app.knowledge.assembler.persist",
-                metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
+                metadata={"doc": doc.doc_name},
             ):
-                # vector_ids = assembler.persist()
-                space = self.get({"name": doc.space})
-                if space and space.vector_type == "KnowledgeGraph":
-                    vector_ids = await assembler.apersist()
-                else:
-                    vector_ids = assembler.persist()
+                assembler = await EmbeddingAssembler.aload_from_knowledge(
+                    knowledge=knowledge,
+                    chunk_parameters=chunk_parameters,
+                    vector_store_connector=vector_store_connector,
+                )
+                chunk_docs = assembler.get_chunks()
+                doc.chunk_size = len(chunk_docs)
+                vector_ids = await assembler.apersist()
                 doc.status = SyncStatus.FINISHED.name
                 doc.result = "document embedding success"
                 if vector_ids is not None:
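The handler above now schedules embedding as a background task and returns immediately, with the document status moving RUNNING → FINISHED (or a failure result) as the task completes. Distilled to plain asyncio (none of these names are DB-GPT's):

```python
# Fire-and-forget pattern distilled; illustration only.
import asyncio


async def embed(doc_name: str) -> None:
    await asyncio.sleep(0.1)  # stands in for chunking + vector-store writes
    print(f"finished embedding {doc_name}")


async def handler() -> str:
    # Schedule the slow work without awaiting it, so the request returns fast.
    asyncio.create_task(embed("report.pdf"))
    return "sync started"


async def main() -> None:
    print(await handler())
    await asyncio.sleep(0.2)  # keep the loop alive until the task finishes


asyncio.run(main())
```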
@@ -37,7 +37,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
     def __init__(self, config: BuiltinKnowledgeGraphConfig):
         """Create builtin knowledge graph instance."""
         self._config = config
+        super().__init__()
         self._llm_client = config.llm_client
         if not self._llm_client:
             raise ValueError("No llm client provided.")
@@ -2,6 +2,7 @@
 import logging
 import math
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, List, Optional

 from dbgpt._private.pydantic import ConfigDict, Field
@@ -9,6 +10,7 @@ from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter
 from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
 from dbgpt.storage.vector_store.filters import MetadataFilters
+from dbgpt.util.executor_utils import blocking_func_to_async
 from dbgpt.util.i18n_utils import _

 logger = logging.getLogger(__name__)
@@ -102,6 +104,10 @@ class VectorStoreConfig(IndexStoreConfig):
 class VectorStoreBase(IndexStoreBase, ABC):
     """Vector store base class."""

+    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
+        """Initialize vector store."""
+        super().__init__(executor)
+
     def filter_by_score_threshold(
         self, chunks: List[Chunk], score_threshold: float
     ) -> List[Chunk]:
@@ -160,7 +166,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
         return 1.0 - distance / math.sqrt(2)

     async def aload_document(self, chunks: List[Chunk]) -> List[str]:  # type: ignore
-        """Load document in index database.
+        """Async load document in index database.

         Args:
             chunks(List[Chunk]): document chunks.
@@ -168,4 +174,4 @@ class VectorStoreBase(IndexStoreBase, ABC):
         Return:
             List[str]: chunk ids.
         """
-        raise NotImplementedError
+        return await blocking_func_to_async(self._executor, self.load_document, chunks)
@@ -62,6 +62,7 @@ class ChromaStore(VectorStoreBase):
         Args:
             vector_store_config(ChromaVectorConfig): vector store config.
         """
+        super().__init__()
         chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
         chroma_path = chroma_vector_config.get(
             "persist_path", os.path.join(PILOT_PATH, "data")
@@ -170,14 +170,22 @@ class VectorStoreConnector:
         )

     async def aload_document(self, chunks: List[Chunk]) -> List[str]:
-        """Load document in vector database.
+        """Async load document in vector database.

         Args:
            - chunks: document chunks.
         Return chunk ids.
         """
-        return await self.client.aload_document(
-            chunks,
-        )
+        max_chunks_once_load = (
+            self._index_store_config.max_chunks_once_load
+            if self._index_store_config
+            else 10
+        )
+        max_threads = (
+            self._index_store_config.max_threads if self._index_store_config else 1
+        )
+        return await self.client.aload_document_with_limit(
+            chunks, max_chunks_once_load, max_threads
+        )

     def similar_search(
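Illustrative call through the connector's async path (the Chunk construction and connector instance are assumptions; the batching defaults of 10 chunks per batch and 1 thread are the fallbacks visible above):

```python
# Hypothetical sketch: async chunk loading via the connector.
from typing import List

from dbgpt.core import Chunk
from dbgpt.storage.vector_store.connector import VectorStoreConnector


async def load(connector: VectorStoreConnector) -> List[str]:
    chunks = [Chunk(content="hello world")]
    # Limits come from the index store config when present, else 10 / 1.
    return await connector.aload_document(chunks)
```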
@@ -125,6 +125,7 @@ class ElasticStore(VectorStoreBase):
         Args:
             vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
         """
+        super().__init__()
         connect_kwargs = {}
         elasticsearch_vector_config = vector_store_config.dict()
         self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
@@ -149,8 +149,14 @@ class MilvusStore(VectorStoreBase):
         vector_store_config (MilvusVectorConfig): MilvusStore config.
         refer to https://milvus.io/docs/v2.0.x/manage_connection.md
         """
-        from pymilvus import connections
+        super().__init__()
+        try:
+            from pymilvus import connections
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
         connect_kwargs = {}
         milvus_vector_config = vector_store_config.to_dict()
         self.uri = milvus_vector_config.get("uri") or os.getenv(
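The same guarded import recurs in several MilvusStore methods below. Factored into a helper it would look like this (the helper is an illustration, not part of the commit, which repeats the try/except inline):

```python
# Illustration only: reusable form of the inline try/except import guard.
def require_pymilvus():
    """Import pymilvus lazily, failing with an actionable message."""
    try:
        import pymilvus
    except ImportError:
        raise ValueError(
            "Could not import pymilvus python package. "
            "Please install it with `pip install pymilvus`."
        )
    return pymilvus
```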
@@ -373,8 +379,13 @@ class MilvusStore(VectorStoreBase):
         self, text, topk, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
         """Perform a search on a query string and return results."""
-        from pymilvus import Collection, DataType
+        try:
+            from pymilvus import Collection, DataType
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
         """similar_search in vector database."""
         self.col = Collection(self.collection_name)
         schema = self.col.schema
@@ -419,7 +430,13 @@ class MilvusStore(VectorStoreBase):
         Returns:
             List[Tuple[Document, float]]: Result doc and score.
         """
-        from pymilvus import Collection
+        try:
+            from pymilvus import Collection, DataType
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )

         self.col = Collection(self.collection_name)
         schema = self.col.schema
@@ -429,7 +446,6 @@ class MilvusStore(VectorStoreBase):
                 self.fields.remove(x.name)
             if x.is_primary:
                 self.primary_field = x.name
-            from pymilvus import DataType

             if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                 self.vector_field = x.name
@@ -526,15 +542,26 @@ class MilvusStore(VectorStoreBase):

     def vector_name_exists(self):
         """Whether vector name exists."""
-        from pymilvus import utility
+        try:
+            from pymilvus import utility
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )

         """is vector store name exist."""
         return utility.has_collection(self.collection_name)

     def delete_vector_name(self, vector_name: str):
         """Delete vector name."""
-        from pymilvus import utility
+        try:
+            from pymilvus import utility
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
         """milvus delete collection name"""
         logger.info(f"milvus vector_name:{vector_name} begin delete...")
         utility.drop_collection(self.collection_name)
@@ -542,8 +569,13 @@ class MilvusStore(VectorStoreBase):

     def delete_by_ids(self, ids):
         """Delete vector by ids."""
-        from pymilvus import Collection
+        try:
+            from pymilvus import Collection
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
         self.col = Collection(self.collection_name)
         # milvus delete vectors by ids
         logger.info(f"begin delete milvus ids: {ids}")
@@ -717,7 +717,7 @@ class OceanBaseStore(VectorStoreBase):
         """Create a OceanBaseStore instance."""
         if vector_store_config.embedding_fn is None:
             raise ValueError("embedding_fn is required for OceanBaseStore")
+        super().__init__()
         self.embeddings = vector_store_config.embedding_fn
         self.collection_name = vector_store_config.name
         vector_store_config = vector_store_config.dict()
@@ -63,6 +63,7 @@ class PGVectorStore(VectorStoreBase):
             raise ImportError(
                 "Please install the `langchain` package to use the PGVector."
             )
+        super().__init__()
         self.connection_string = vector_store_config.connection_string
         self.embeddings = vector_store_config.embedding_fn
         self.collection_name = vector_store_config.name
@@ -68,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
             "Could not import weaviate python package. "
             "Please install it with `pip install weaviate-client`."
         )
+        super().__init__()
         self.weaviate_url = vector_store_config.weaviate_url
         self.embedding = vector_store_config.embedding_fn
         self.vector_name = vector_store_config.name
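Every store's __init__ now calls super().__init__() so the executor created in IndexStoreBase exists before any async wrapper runs; a store that skipped the call would raise AttributeError on self._executor. A standalone sketch of that contract (these class names are stand-ins, not DB-GPT's):

```python
# Standalone illustration of the executor-inheritance contract; not DB-GPT code.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional


class IndexStoreBaseSketch:
    def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
        self._executor = executor or ThreadPoolExecutor()

    def load_document(self, chunks: List[str]) -> List[str]:
        raise NotImplementedError

    async def aload_document(self, chunks: List[str]) -> List[str]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self.load_document, chunks)


class MyStore(IndexStoreBaseSketch):
    def __init__(self):
        super().__init__()  # without this, aload_document fails on _executor
        self._docs: List[str] = []

    def load_document(self, chunks: List[str]) -> List[str]:
        self._docs.extend(chunks)
        return chunks
```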