mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
fix(ChatKnowledge): add aload_document (#1548)
This commit is contained in:
@@ -27,6 +27,7 @@ from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
@@ -235,13 +236,30 @@ async def document_upload(
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/sync")
async def document_sync(
    space_name: str,
    request: DocumentSyncRequest,
    service: Service = Depends(get_rag_service),
):
    """Start syncing a document of a knowledge space into the vector store.

    Args:
        space_name: Name of the knowledge space, taken from the URL path.
        request: Sync request body; ``doc_ids``, ``model_name``,
            ``chunk_size`` and ``chunk_overlap`` are read from it.
        service: RAG service resolved through ``get_rag_service``.

    Returns:
        ``Result.succ`` wrapping whatever ``service.sync_document`` returns,
        or ``Result.failed`` (code ``E000X``) when the space does not exist,
        no document id was supplied, or any exception is raised.
    """
    logger.info(f"Received params: {space_name}, {request}")
    try:
        space = service.get({"name": space_name})
        if space is None:
            return Result.failed(code="E000X", msg=f"space {space_name} not exist")
        if not request.doc_ids:
            return Result.failed(code="E000X", msg="doc_ids is None")

        # NOTE(review): only the first id in ``doc_ids`` is synced; any
        # further ids in the request are ignored - confirm this is intended.
        sync_request = KnowledgeSyncRequest(
            doc_id=request.doc_ids[0],
            space_id=str(space.id),
            model_name=request.model_name,
        )
        # Fall back to 512 / 50 when the caller omits (or zeroes) the sizes.
        sync_request.chunk_parameters = ChunkParameters(
            chunk_strategy="Automatic",
            chunk_size=request.chunk_size or 512,
            chunk_overlap=request.chunk_overlap or 50,
        )
        doc_ids = await service.sync_document(requests=[sync_request])
        return Result.succ(doc_ids)
    except Exception as e:
        # API boundary: report the failure to the client, but keep the
        # traceback in the server log instead of dropping it.
        logger.exception("document sync error, space: %s", space_name)
        return Result.failed(code="E000X", msg=f"document sync error {e}")
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
|
||||
@@ -32,13 +31,8 @@ from dbgpt.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
|
||||
from dbgpt.rag.knowledge.base import KnowledgeType
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import (
|
||||
RecursiveCharacterTextSplitter,
|
||||
SpacyTextSplitter,
|
||||
)
|
||||
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
|
||||
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||
from dbgpt.serve.rag.service.service import SyncStatus
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
@@ -199,186 +193,6 @@ class KnowledgeService:
|
||||
total = knowledge_document_dao.get_knowledge_documents_count(query)
|
||||
return DocumentQueryResponse(data=data, total=total, page=page)
|
||||
|
||||
def batch_document_sync(
    self,
    space_name,
    sync_requests: List[KnowledgeSyncRequest],
) -> List[int]:
    """Batch sync knowledge document chunks into the vector store.

    Args:
        space_name: Knowledge space name.
        sync_requests: One request per document; ``doc_id`` and
            ``chunk_parameters`` are read from each.

    Returns:
        List[int]: ids of the documents handed to
        ``_sync_knowledge_document``, in request order.

    Raises:
        Exception: If a requested document does not exist, or its status is
            already RUNNING or FINISHED.
    """
    doc_ids = []
    for sync_request in sync_requests:
        docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
        if not docs:
            raise Exception(
                f"there are no document called, doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        # Any strategy other than CHUNK_BY_SIZE has its size/overlap
        # overwritten with the space-level settings, or with the global
        # defaults when the space has no context.
        if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
            space_context = self.get_space_context(space_name)
            chunk_parameters.chunk_size = (
                CFG.KNOWLEDGE_CHUNK_SIZE
                if space_context is None
                else int(space_context["embedding"]["chunk_size"])
            )
            chunk_parameters.chunk_overlap = (
                CFG.KNOWLEDGE_CHUNK_OVERLAP
                if space_context is None
                else int(space_context["embedding"]["chunk_overlap"])
            )
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids
|
||||
|
||||
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
    """Sync knowledge document chunks into the vector store.

    Args:
        space_name: Knowledge space name.
        sync_request: DocumentSyncRequest; ``doc_ids``, ``model_name``,
            ``chunk_size``, ``chunk_overlap``, ``separators`` and
            ``pre_separator`` are read from it.

    Returns:
        The id of the last document handed to ``_sync_knowledge_document``.

    Raises:
        ValueError: If ``doc_ids`` is empty, or several separators are given
            while the spaCy splitter is selected.
        Exception: If a document does not exist, or its status is already
            RUNNING or FINISHED.
    """
    from dbgpt.rag.chunk_manager import ChunkParameters
    from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter

    doc_ids = sync_request.doc_ids
    self.model_name = sync_request.model_name or CFG.LLM_MODEL
    if not doc_ids:
        # Fail with a clear message instead of an unbound-local error at the
        # final return.
        raise ValueError("doc_ids is empty, nothing to sync")

    last_doc_id = None
    for doc_id in doc_ids:
        query = KnowledgeDocumentEntity(id=doc_id)
        docs = knowledge_document_dao.get_documents(query)
        if not docs:
            # The request only carries ``doc_ids``; report the id that was
            # actually looked up.
            raise Exception(f"there are no document called, doc_id: {doc_id}")
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )

        # Precedence: explicit request value > space context > global config.
        space_context = self.get_space_context(space_name)
        chunk_size = (
            CFG.KNOWLEDGE_CHUNK_SIZE
            if space_context is None
            else int(space_context["embedding"]["chunk_size"])
        )
        chunk_overlap = (
            CFG.KNOWLEDGE_CHUNK_OVERLAP
            if space_context is None
            else int(space_context["embedding"]["chunk_overlap"])
        )
        if sync_request.chunk_size:
            chunk_size = sync_request.chunk_size
        if sync_request.chunk_overlap:
            chunk_overlap = sync_request.chunk_overlap
        separators = sync_request.separators or None

        chunk_parameters = ChunkParameters(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if CFG.LANGUAGE == "en":
            text_splitter = RecursiveCharacterTextSplitter(
                separators=separators,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
            )
        else:
            if separators and len(separators) > 1:
                raise ValueError(
                    "SpacyTextSplitter do not support multipsle separators"
                )
            try:
                separator = "\n\n" if not separators else separators[0]
                text_splitter = SpacyTextSplitter(
                    separator=separator,
                    pipeline="zh_core_web_sm",
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            except Exception as e:
                # Best-effort fallback when the spaCy splitter cannot be
                # built; log the reason rather than hiding it.
                logger.warning(
                    "SpacyTextSplitter unavailable, falling back to "
                    "RecursiveCharacterTextSplitter: %s",
                    e,
                )
                text_splitter = RecursiveCharacterTextSplitter(
                    separators=separators,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
        if sync_request.pre_separator:
            logger.info(f"Use preseparator, {sync_request.pre_separator}")
            # Wrap the chosen splitter so the pre-separator is applied too.
            text_splitter = PreTextSplitter(
                pre_separator=sync_request.pre_separator,
                text_splitter_impl=text_splitter,
            )
        chunk_parameters.text_splitter = text_splitter
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        last_doc_id = doc.id
    return last_doc_id
|
||||
|
||||
def _sync_knowledge_document(
    self,
    space_name,
    doc: KnowledgeDocumentEntity,
    chunk_parameters: ChunkParameters,
) -> List[Chunk]:
    """Chunk one document and schedule its embedding into the vector store.

    Args:
        space_name: Name of the knowledge space; must match exactly one space.
        doc: Document entity; ``content`` and ``doc_type`` are read, while
            ``status``, ``chunk_size`` and ``gmt_modified`` are updated.
        chunk_parameters: Chunking settings passed to the assembler.

    Returns:
        The chunks produced by the assembler. Embedding is only submitted to
        the executor here, not awaited.

    Raises:
        Exception: If ``space_name`` does not resolve to exactly one space.
    """
    factory = CFG.SYSTEM_APP.get_component("embedding_factory", EmbeddingFactory)
    embeddings = factory.create(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    )

    matched_spaces = self.get_knowledge_space(
        KnowledgeSpaceRequest(name=space_name)
    )
    if len(matched_spaces) != 1:
        raise Exception(f"invalid space name:{space_name}")
    target_space = matched_spaces[0]

    from dbgpt.storage.vector_store.base import VectorStoreConfig

    store_config = VectorStoreConfig(
        name=target_space.name,
        embedding_fn=embeddings,
        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
        llm_client=self.llm_client,
        model_name=self.model_name,
    )
    connector = VectorStoreConnector(
        vector_store_type=target_space.vector_type,
        vector_store_config=store_config,
    )
    source = KnowledgeFactory.create(
        datasource=doc.content,
        knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=source,
        chunk_parameters=chunk_parameters,
        embeddings=embeddings,
        vector_store_connector=connector,
    )
    chunks = assembler.get_chunks()

    # Persist the RUNNING state and chunk count before the background
    # embedding job is submitted.
    doc.status = SyncStatus.RUNNING.name
    doc.chunk_size = len(chunks)
    doc.gmt_modified = datetime.now()
    knowledge_document_dao.update_knowledge_document(doc)

    executor_factory = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    )
    worker_pool = executor_factory.create()
    worker_pool.submit(self.async_doc_embedding, assembler, chunks, doc)
    logger.info(f"begin save document chunks, doc:{doc.doc_name}")
    return chunks
|
||||
|
||||
async def document_summary(self, request: DocumentSummaryRequest):
|
||||
"""get document summary
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user