mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 09:06:55 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, Form
|
||||
|
||||
@@ -13,10 +14,10 @@ from dbgpt.configs.model_config import (
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
|
||||
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
|
||||
from dbgpt.app.knowledge.service import KnowledgeService
|
||||
from dbgpt.rag.knowledge.factory import KnowledgeFactory
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
KnowledgeQueryRequest,
|
||||
KnowledgeQueryResponse,
|
||||
@@ -27,9 +28,14 @@ from dbgpt.app.knowledge.request.request import (
|
||||
SpaceArgumentRequest,
|
||||
EntityExtractRequest,
|
||||
DocumentSummaryRequest,
|
||||
KnowledgeSyncRequest,
|
||||
)
|
||||
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.util.tracer import root_tracer, SpanType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -103,6 +109,39 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
|
||||
return Result.failed(code="E000X", msg=f"document add error {e}")
|
||||
|
||||
|
||||
@router.get("/knowledge/document/chunkstrategies")
def chunk_strategies():
    """Get chunk strategies.

    Builds one entry per ``ChunkStrategy`` member. Each entry carries:

    * ``strategy``: the enum member name.
    * ``name`` / ``description`` / ``parameters``: taken from positions
      2, 3 and 1 of the member's value.
    * ``suffix``: ``document_type().value`` of every ``KnowledgeFactory``
      subclass that supports the strategy and reports a document type.
    * ``type``: the distinct ``type().value`` of every supporting subclass.

    Returns:
        ``Result.succ`` wrapping the list of entries, or ``Result.failed``
        with code ``E000X`` if anything raises while building it.
    """
    # Use the module logger (not print) so this endpoint logs like its siblings.
    logger.info("/document/chunkstrategies:")
    try:
        # Resolve the subclass list once rather than twice per strategy.
        knowledge_classes = list(KnowledgeFactory.subclasses())
        strategies = []
        for strategy in ChunkStrategy:
            supporting = [
                knowledge
                for knowledge in knowledge_classes
                if strategy in knowledge.support_chunk_strategy()
            ]
            strategies.append(
                {
                    "strategy": strategy.name,
                    "name": strategy.value[2],
                    "description": strategy.value[3],
                    "parameters": strategy.value[1],
                    "suffix": [
                        knowledge.document_type().value
                        for knowledge in supporting
                        if knowledge.document_type() is not None
                    ],
                    "type": {knowledge.type().value for knowledge in supporting},
                }
            )
        return Result.succ(strategies)
    except Exception as e:
        # Keep the traceback in the logs; the client only receives the message.
        logger.exception("chunk strategies error")
        return Result.failed(code="E000X", msg=f"chunk strategies error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/list")
|
||||
def document_list(space_name: str, query_request: DocumentQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
@@ -189,6 +228,18 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
|
||||
return Result.failed(code="E000X", msg=f"document sync error {e}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/sync_batch")
def batch_document_sync(space_name: str, request: List[KnowledgeSyncRequest]):
    """Sync a batch of documents in a knowledge space.

    Args:
        space_name: knowledge space name taken from the URL path.
        request: sync requests, forwarded to the service as ``sync_requests``.

    Returns:
        ``Result.succ`` with the service's return value under ``"tasks"``,
        or ``Result.failed`` with code ``E000X`` when an exception is raised.
    """
    logger.info(f"Received params: {space_name}, {request}")
    try:
        task_ids = knowledge_space_service.batch_document_sync(
            sync_requests=request,
            space_name=space_name,
        )
        payload = {"tasks": task_ids}
        return Result.succ(payload)
    except Exception as exc:
        return Result.failed(code="E000X", msg=f"document sync error {exc}")
|
||||
|
||||
|
||||
@router.post("/knowledge/{space_name}/chunk/list")
|
||||
def document_list(space_name: str, query_request: ChunkQueryRequest):
|
||||
print(f"/document/list params: {space_name}, {query_request}")
|
||||
@@ -204,15 +255,23 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config={"vector_store_name": space_name},
|
||||
embedding_factory=embedding_factory,
|
||||
config = VectorStoreConfig(
|
||||
name=space_name,
|
||||
embedding_fn=embedding_factory.create(
|
||||
EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
),
|
||||
)
|
||||
docs = client.similar_search(query_request.query, query_request.top_k)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_config=config,
|
||||
)
|
||||
retriever = EmbeddingRetriever(
|
||||
top_k=query_request.top_k, vector_store_connector=vector_store_connector
|
||||
)
|
||||
chunks = retriever.retrieve(query_request.query)
|
||||
res = [
|
||||
KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
|
||||
for d in docs
|
||||
KnowledgeQueryResponse(text=d.content, source=d.metadata["source"])
|
||||
for d in chunks
|
||||
]
|
||||
return {"response": res}
|
||||
|
||||
@@ -254,7 +313,7 @@ async def entity_extract(request: EntityExtractRequest):
|
||||
logger.info(f"Received params: {request}")
|
||||
try:
|
||||
from dbgpt.app.scene import ChatScene
|
||||
from dbgpt._private.chat_util import llm_chat_response_nostream
|
||||
from dbgpt.util.chat_util import llm_chat_response_nostream
|
||||
import uuid
|
||||
|
||||
chat_param = {
|
||||
|
Reference in New Issue
Block a user