refactor: rag storage refactor (#2434)

This commit is contained in:
Aries-ckt
2025-03-17 14:15:21 +08:00
committed by GitHub
parent 34d86d693c
commit fc3fe6b725
52 changed files with 1134 additions and 797 deletions

View File

@@ -5,6 +5,7 @@ from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR, resolve_root_path
 from dbgpt.util.executor_utils import DefaultExecutorFactory
 from dbgpt_app.config import ApplicationConfig, ServiceWebParameters
+from dbgpt_serve.rag.storage_manager import StorageManager
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def initialize_components(
     system_app.register(DefaultScheduler)
     system_app.register_instance(controller)
     system_app.register(ConnectorManager)
+    system_app.register(StorageManager)
     from dbgpt_serve.agent.hub.controller import module_plugin

View File

@@ -7,16 +7,15 @@ from dbgpt.model.parameter import (
     ModelServiceConfig,
 )
 from dbgpt.storage.cache.manager import ModelCacheParameters
-from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.util.configure import HookConfig
 from dbgpt.util.i18n_utils import _
 from dbgpt.util.parameter_utils import BaseParameters
 from dbgpt.util.tracer import TracerParameters
 from dbgpt.util.utils import LoggingParameters
 from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnectorParameters
-from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
-    BuiltinKnowledgeGraphConfig,
-)
+from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
+from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig
+from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
 from dbgpt_serve.core import BaseServeConfig
@@ -52,22 +51,22 @@ class SystemParameters:
 @dataclass
 class StorageConfig(BaseParameters):
-    vector: VectorStoreConfig = field(
-        default_factory=VectorStoreConfig,
+    vector: Optional[ChromaVectorConfig] = field(
+        default_factory=lambda: ChromaVectorConfig(),
         metadata={
             "help": _("default vector type"),
         },
     )
-    graph: BuiltinKnowledgeGraphConfig = field(
-        default_factory=BuiltinKnowledgeGraphConfig,
+    graph: Optional[TuGraphStoreConfig] = field(
+        default=None,
         metadata={
             "help": _("default graph type"),
         },
     )
-    full_text: BuiltinKnowledgeGraphConfig = field(
-        default_factory=BuiltinKnowledgeGraphConfig,
+    full_text: Optional[ElasticsearchStoreConfig] = field(
+        default=None,
         metadata={
-            "help": _("default graph type"),
+            "help": _("default full text type"),
         },
     )
@@ -93,7 +92,7 @@ class RagParameters(BaseParameters):
         default=10,
         metadata={"help": _("knowledge search top k")},
     )
-    similarity_score_threshold: Optional[int] = field(
+    similarity_score_threshold: Optional[float] = field(
         default=0.0,
         metadata={"help": _("knowledge search top similarity score")},
     )
@@ -117,14 +116,86 @@ class RagParameters(BaseParameters):
         default_factory=lambda: StorageConfig(),
         metadata={"help": _("Storage configuration")},
     )
-    graph_search_top_k: Optional[int] = field(
-        default=3,
+    knowledge_graph_chunk_search_top_k: Optional[int] = field(
+        default=5,
         metadata={"help": _("knowledge graph search top k")},
     )
-    graph_community_summary_enabled: Optional[bool] = field(
+    kg_enable_summary: Optional[bool] = field(
         default=False,
         metadata={"help": _("graph community summary enabled")},
     )
+    llm_model: Optional[str] = field(
+        default=None,
+        metadata={"help": _("kg extract llm model")},
+    )
+    kg_extract_top_k: Optional[int] = field(
+        default=5,
+        metadata={"help": _("kg extract top k")},
+    )
+    kg_extract_score_threshold: Optional[float] = field(
+        default=0.3,
+        metadata={"help": _("kg extract score threshold")},
+    )
+    kg_community_top_k: Optional[int] = field(
+        default=50,
+        metadata={"help": _("kg community top k")},
+    )
+    kg_community_score_threshold: Optional[float] = field(
+        default=0.3,
+        metadata={"help": _("kg_community_score_threshold")},
+    )
+    kg_triplet_graph_enabled: Optional[bool] = field(
+        default=True,
+        metadata={"help": _("kg_triplet_graph_enabled")},
+    )
+    kg_document_graph_enabled: Optional[bool] = field(
+        default=True,
+        metadata={"help": _("kg_document_graph_enabled")},
+    )
+    kg_chunk_search_top_k: Optional[int] = field(
+        default=5,
+        metadata={"help": _("kg_chunk_search_top_k")},
+    )
+    kg_extraction_batch_size: Optional[int] = field(
+        default=3,
+        metadata={"help": _("kg_extraction_batch_size")},
+    )
+    kg_community_summary_batch_size: Optional[int] = field(
+        default=20,
+        metadata={"help": _("kg_community_summary_batch_size")},
+    )
+    kg_embedding_batch_size: Optional[int] = field(
+        default=20,
+        metadata={"help": _("kg_embedding_batch_size")},
+    )
+    kg_similarity_top_k: Optional[int] = field(
+        default=5,
+        metadata={"help": _("kg_similarity_top_k")},
+    )
+    kg_similarity_score_threshold: Optional[float] = field(
+        default=0.7,
+        metadata={"help": _("kg_similarity_score_threshold")},
+    )
+    kg_enable_text_search: Optional[bool] = field(
+        default=False,
+        metadata={"help": _("kg_enable_text_search")},
+    )
+    kg_text2gql_model_enabled: Optional[bool] = field(
+        default=False,
+        metadata={"help": _("kg_text2gql_model_enabled")},
+    )
+    kg_text2gql_model_name: Optional[str] = field(
+        default=None,
+        metadata={"help": _("text2gql_model_name")},
+    )
+    bm25_k1: Optional[float] = field(
+        default=2.0,
+        metadata={"help": _("bm25_k1")},
+    )
+    bm25_b: Optional[float] = field(
+        default=0.75,
+        metadata={"help": _("bm25_b")},
+    )
@dataclass

View File

@@ -257,6 +257,7 @@ def run_webserver(config_file: str):
 def scan_configs():
     from dbgpt.model import scan_model_providers
     from dbgpt_app.initialization.serve_initialization import scan_serve_configs
+    from dbgpt_ext.storage import scan_storage_configs
     from dbgpt_serve.datasource.manages.connector_manager import ConnectorManager
     cm = ConnectorManager(system_app)
@@ -266,6 +267,8 @@ def scan_configs():
     scan_model_providers()
     # Register all serve configs
     scan_serve_configs()
+    # Register all storage configs
+    scan_storage_configs()
 def load_config(config_file: str = None) -> ApplicationConfig:

View File

@@ -12,10 +12,8 @@ from dbgpt.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
 )
 from dbgpt.core.awel.dag.dag_manager import DAGManager
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.retriever import BaseRetriever
 from dbgpt.rag.retriever.embedding import EmbeddingRetriever
-from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.util.i18n_utils import _
 from dbgpt.util.tracer import SpanType, root_tracer
 from dbgpt_app.knowledge.request.request import (
@@ -49,8 +47,10 @@ from dbgpt_serve.rag.api.schemas import (
     KnowledgeStorageType,
     KnowledgeSyncRequest,
 )
-from dbgpt_serve.rag.connector import VectorStoreConnector
+# from dbgpt_serve.rag.connector import VectorStoreConnector
 from dbgpt_serve.rag.service.service import Service
+from dbgpt_serve.rag.storage_manager import StorageManager
 logger = logging.getLogger(__name__)
@@ -505,22 +505,12 @@
 @router.post("/knowledge/{vector_name}/query")
-def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
+def similarity_query(space_name: str, query_request: KnowledgeQueryRequest):
     print(f"Received params: {space_name}, {query_request}")
-    embedding_factory = CFG.SYSTEM_APP.get_component(
-        "embedding_factory", EmbeddingFactory
-    )
-    config = VectorStoreConfig(
-        name=space_name,
-        embedding_fn=embedding_factory.create(),
-    )
-    vector_store_connector = VectorStoreConnector(
-        vector_store_type=CFG.VECTOR_STORE_TYPE,
-        vector_store_config=config,
-        system_app=CFG.SYSTEM_APP,
-    )
+    storage_manager = StorageManager.get_instance(CFG.SYSTEM_APP)
+    vector_store_connector = storage_manager.create_vector_store(index_name=space_name)
     retriever = EmbeddingRetriever(
-        top_k=query_request.top_k, index_store=vector_store_connector.index_client
+        top_k=query_request.top_k, index_store=vector_store_connector
     )
     chunks = retriever.retrieve(query_request.query)
     res = [

View File

@@ -12,10 +12,8 @@ from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import LLMClient
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
-from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
 from dbgpt.util.tracer import root_tracer, trace
 from dbgpt_app.knowledge.request.request import (
@@ -36,7 +34,6 @@ from dbgpt_app.knowledge.request.response import (
 from dbgpt_ext.rag.assembler.summary import SummaryAssembler
 from dbgpt_ext.rag.chunk_manager import ChunkParameters
 from dbgpt_ext.rag.knowledge.factory import KnowledgeFactory
-from dbgpt_serve.rag.connector import VectorStoreConnector
 from dbgpt_serve.rag.models.chunk_db import DocumentChunkDao, DocumentChunkEntity
 from dbgpt_serve.rag.models.document_db import (
     KnowledgeDocumentDao,
@@ -45,6 +42,7 @@ from dbgpt_serve.rag.models.document_db import (
 from dbgpt_serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
 from dbgpt_serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
 from dbgpt_serve.rag.service.service import SyncStatus
+from dbgpt_serve.rag.storage_manager import StorageManager
 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
@@ -82,6 +80,14 @@ class KnowledgeService:
         rag_config = CFG.SYSTEM_APP.config.configs.get("app_config").rag
         return rag_config
+    @property
+    def storage_manager(self):
+        return StorageManager.get_instance(CFG.SYSTEM_APP)
+
+    @property
+    def system_app(self):
+        return CFG.SYSTEM_APP
+
     def create_knowledge_space(self, request: KnowledgeSpaceRequest):
"""create knowledge space
Args:
@@ -91,7 +97,7 @@
             name=request.name,
         )
         if request.vector_type == "VectorStore":
-            request.vector_type = self.rag_config.storage.vector.get("type")
+            request.vector_type = self.rag_config.storage.vector.get_type_value()
         if request.vector_type == "KnowledgeGraph":
             knowledge_space_name_pattern = r"^[a-zA-Z0-9\u4e00-\u9fa5]+$"
             if not re.match(knowledge_space_name_pattern, request.name):
@@ -412,28 +418,15 @@ class KnowledgeService:
         if len(spaces) != 1:
             raise Exception(f"invalid space name:{space_name}")
         space = spaces[0]
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
-        embedding_fn = embedding_factory.create()
-        config = VectorStoreConfig(
-            name=space.name,
-            embedding_fn=embedding_fn,
-            llm_client=self.llm_client,
-            model_name=None,
-        )
         if space.domain_type == DOMAIN_TYPE_FINANCIAL_REPORT:
             conn_manager = CFG.local_db_manager
             conn_manager.delete_db(f"{space.name}_fin_report")
-        vector_store_connector = VectorStoreConnector(
-            vector_store_type=space.vector_type,
-            vector_store_config=config,
-            system_app=CFG.SYSTEM_APP,
+        storage_connector = self.storage_manager.get_storage_connector(
+            index_name=space_name, storage_type=space.vector_type
         )
         # delete vectors
-        vector_store_connector.delete_vector_name(space.name)
+        storage_connector.delete_vector_name(space.name)
         document_query = KnowledgeDocumentEntity(space=space.name)
         # delete chunks
         documents = knowledge_document_dao.get_documents(document_query)
@@ -462,23 +455,11 @@ class KnowledgeService:
             vector_ids = documents[0].vector_ids
         if vector_ids is not None:
-            embedding_factory = CFG.SYSTEM_APP.get_component(
-                "embedding_factory", EmbeddingFactory
-            )
-            embedding_fn = embedding_factory.create()
-            config = VectorStoreConfig(
-                name=space.name,
-                embedding_fn=embedding_fn,
-                llm_client=self.llm_client,
-                model_name=None,
-            )
-            vector_store_connector = VectorStoreConnector(
-                vector_store_type=space.vector_type,
-                vector_store_config=config,
-                system_app=CFG.SYSTEM_APP,
+            storage_connector = self.storage_manager.get_storage_connector(
+                index_name=space_name, storage_type=space.vector_type
             )
             # delete vector by ids
-            vector_store_connector.delete_by_ids(vector_ids)
+            storage_connector.delete_by_ids(vector_ids)
             # delete chunks
             document_chunk_dao.raw_delete(documents[0].id)
             # delete document
@@ -628,29 +609,12 @@ class KnowledgeService:
         return chat
     def query_graph(self, space_name, limit):
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
-        embedding_fn = embedding_factory.create()
         spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
         if len(spaces) != 1:
             raise Exception(f"invalid space name:{space_name}")
-        space = spaces[0]
-        print(CFG.LLM_MODEL)
-        config = VectorStoreConfig(
-            name=space.name,
-            embedding_fn=embedding_fn,
-            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
-            llm_client=self.llm_client,
-            model_name=None,
-        )
-        vector_store_connector = VectorStoreConnector(
-            vector_store_type=space.vector_type,
-            vector_store_config=config,
-            system_app=CFG.SYSTEM_APP,
-        )
-        graph = vector_store_connector.client.query_graph(limit=limit)
+        # space = spaces[0]
+        graph_store = self.storage_manager.create_kg_store(index_name=space_name)
+        graph = graph_store.query_graph(limit=limit)
         res = {"nodes": [], "edges": []}
         for node in graph.vertices():
             res["nodes"].append(

View File

@@ -234,7 +234,7 @@ class ChatKnowledge(BaseChat):
         from dbgpt_ext.storage import __knowledge_graph__ as graph_storages
         if spaces[0].vector_type in graph_storages:
-            return self.rag_config.graph_search_top_k
+            return self.rag_config.kg_chunk_search_top_k
         return self.rag_config.similarity_top_k