mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-26 12:04:39 +00:00
refactor: rag storage refactor (#2434)
This commit is contained in:
@@ -5,6 +5,7 @@ from dbgpt.component import SystemApp
|
||||
from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR, resolve_root_path
|
||||
from dbgpt.util.executor_utils import DefaultExecutorFactory
|
||||
from dbgpt_app.config import ApplicationConfig, ServiceWebParameters
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +35,7 @@ def initialize_components(
|
||||
system_app.register(DefaultScheduler)
|
||||
system_app.register_instance(controller)
|
||||
system_app.register(ConnectorManager)
|
||||
system_app.register(StorageManager)
|
||||
|
||||
from dbgpt_serve.agent.hub.controller import module_plugin
|
||||
|
||||
|
@@ -7,16 +7,15 @@ from dbgpt.model.parameter import (
|
||||
ModelServiceConfig,
|
||||
)
|
||||
from dbgpt.storage.cache.manager import ModelCacheParameters
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.util.configure import HookConfig
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.parameter_utils import BaseParameters
|
||||
from dbgpt.util.tracer import TracerParameters
|
||||
from dbgpt.util.utils import LoggingParameters
|
||||
from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnectorParameters
|
||||
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
|
||||
BuiltinKnowledgeGraphConfig,
|
||||
)
|
||||
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
|
||||
from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig
|
||||
from dbgpt_serve.core import BaseServeConfig
|
||||
|
||||
|
||||
@@ -52,22 +51,22 @@ class SystemParameters:
|
||||
|
||||
@dataclass
|
||||
class StorageConfig(BaseParameters):
|
||||
vector: VectorStoreConfig = field(
|
||||
default_factory=VectorStoreConfig,
|
||||
vector: Optional[ChromaVectorConfig] = field(
|
||||
default_factory=lambda: ChromaVectorConfig(),
|
||||
metadata={
|
||||
"help": _("default vector type"),
|
||||
},
|
||||
)
|
||||
graph: BuiltinKnowledgeGraphConfig = field(
|
||||
default_factory=BuiltinKnowledgeGraphConfig,
|
||||
graph: Optional[TuGraphStoreConfig] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("default graph type"),
|
||||
},
|
||||
)
|
||||
full_text: BuiltinKnowledgeGraphConfig = field(
|
||||
default_factory=BuiltinKnowledgeGraphConfig,
|
||||
full_text: Optional[ElasticsearchStoreConfig] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("default graph type"),
|
||||
"help": _("default full text type"),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -93,7 +92,7 @@ class RagParameters(BaseParameters):
|
||||
default=10,
|
||||
metadata={"help": _("knowledge search top k")},
|
||||
)
|
||||
similarity_score_threshold: Optional[int] = field(
|
||||
similarity_score_threshold: Optional[float] = field(
|
||||
default=0.0,
|
||||
metadata={"help": _("knowledge search top similarity score")},
|
||||
)
|
||||
@@ -117,14 +116,86 @@ class RagParameters(BaseParameters):
|
||||
default_factory=lambda: StorageConfig(),
|
||||
metadata={"help": _("Storage configuration")},
|
||||
)
|
||||
graph_search_top_k: Optional[int] = field(
|
||||
default=3,
|
||||
knowledge_graph_chunk_search_top_k: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": _("knowledge graph search top k")},
|
||||
)
|
||||
graph_community_summary_enabled: Optional[bool] = field(
|
||||
kg_enable_summary: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": _("graph community summary enabled")},
|
||||
)
|
||||
llm_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": _("kg extract llm model")},
|
||||
)
|
||||
kg_extract_top_k: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": _("kg extract top k")},
|
||||
)
|
||||
kg_extract_score_threshold: Optional[float] = field(
|
||||
default=0.3,
|
||||
metadata={"help": _("kg extract score threshold")},
|
||||
)
|
||||
kg_community_top_k: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": _("kg community top k")},
|
||||
)
|
||||
kg_community_score_threshold: Optional[float] = field(
|
||||
default=0.3,
|
||||
metadata={"help": _("kg_community_score_threshold")},
|
||||
)
|
||||
kg_triplet_graph_enabled: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": _("kg_triplet_graph_enabled")},
|
||||
)
|
||||
kg_document_graph_enabled: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": _("kg_document_graph_enabled")},
|
||||
)
|
||||
kg_chunk_search_top_k: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": _("kg_chunk_search_top_k")},
|
||||
)
|
||||
kg_extraction_batch_size: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": _("kg_extraction_batch_size")},
|
||||
)
|
||||
kg_community_summary_batch_size: Optional[int] = field(
|
||||
default=20,
|
||||
metadata={"help": _("kg_community_summary_batch_size")},
|
||||
)
|
||||
kg_embedding_batch_size: Optional[int] = field(
|
||||
default=20,
|
||||
metadata={"help": _("kg_embedding_batch_size")},
|
||||
)
|
||||
kg_similarity_top_k: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": _("kg_similarity_top_k")},
|
||||
)
|
||||
kg_similarity_score_threshold: Optional[float] = field(
|
||||
default=0.7,
|
||||
metadata={"help": _("kg_similarity_score_threshold")},
|
||||
)
|
||||
kg_enable_text_search: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": _("kg_enable_text_search")},
|
||||
)
|
||||
kg_text2gql_model_enabled: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": _("kg_text2gql_model_enabled")},
|
||||
)
|
||||
kg_text2gql_model_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": _("text2gql_model_name")},
|
||||
)
|
||||
bm25_k1: Optional[float] = field(
|
||||
default=2.0,
|
||||
metadata={"help": _("bm25_k1")},
|
||||
)
|
||||
bm25_b: Optional[float] = field(
|
||||
default=0.75,
|
||||
metadata={"help": _("bm25_b")},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@@ -257,6 +257,7 @@ def run_webserver(config_file: str):
|
||||
def scan_configs():
|
||||
from dbgpt.model import scan_model_providers
|
||||
from dbgpt_app.initialization.serve_initialization import scan_serve_configs
|
||||
from dbgpt_ext.storage import scan_storage_configs
|
||||
from dbgpt_serve.datasource.manages.connector_manager import ConnectorManager
|
||||
|
||||
cm = ConnectorManager(system_app)
|
||||
@@ -266,6 +267,8 @@ def scan_configs():
|
||||
scan_model_providers()
|
||||
# Register all serve configs
|
||||
scan_serve_configs()
|
||||
# Register all storage configs
|
||||
scan_storage_configs()
|
||||
|
||||
|
||||
def load_config(config_file: str = None) -> ApplicationConfig:
|
||||
|
@@ -12,10 +12,8 @@ from dbgpt.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
)
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever import BaseRetriever
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.tracer import SpanType, root_tracer
|
||||
from dbgpt_app.knowledge.request.request import (
|
||||
@@ -49,8 +47,10 @@ from dbgpt_serve.rag.api.schemas import (
|
||||
KnowledgeStorageType,
|
||||
KnowledgeSyncRequest,
|
||||
)
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
|
||||
# from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt_serve.rag.service.service import Service
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -505,22 +505,12 @@ def chunk_edit(
|
||||
|
||||
|
||||
@router.post("/knowledge/{vector_name}/query")
|
||||
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||
def similarity_query(space_name: str, query_request: KnowledgeQueryRequest):
|
||||
print(f"Received params: {space_name}, {query_request}")
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
config = VectorStoreConfig(
|
||||
name=space_name,
|
||||
embedding_fn=embedding_factory.create(),
|
||||
)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_config=config,
|
||||
system_app=CFG.SYSTEM_APP,
|
||||
)
|
||||
storage_manager = StorageManager.get_instance(CFG.SYSTEM_APP)
|
||||
vector_store_connector = storage_manager.create_vector_store(index_name=space_name)
|
||||
retriever = EmbeddingRetriever(
|
||||
top_k=query_request.top_k, index_store=vector_store_connector.index_client
|
||||
top_k=query_request.top_k, index_store=vector_store_connector
|
||||
)
|
||||
chunks = retriever.retrieve(query_request.query)
|
||||
res = [
|
||||
|
@@ -12,10 +12,8 @@ from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import KnowledgeType
|
||||
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
from dbgpt_app.knowledge.request.request import (
|
||||
@@ -36,7 +34,6 @@ from dbgpt_app.knowledge.request.response import (
|
||||
from dbgpt_ext.rag.assembler.summary import SummaryAssembler
|
||||
from dbgpt_ext.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt_ext.rag.knowledge.factory import KnowledgeFactory
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt_serve.rag.models.chunk_db import DocumentChunkDao, DocumentChunkEntity
|
||||
from dbgpt_serve.rag.models.document_db import (
|
||||
KnowledgeDocumentDao,
|
||||
@@ -45,6 +42,7 @@ from dbgpt_serve.rag.models.document_db import (
|
||||
from dbgpt_serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||
from dbgpt_serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||
from dbgpt_serve.rag.service.service import SyncStatus
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
knowledge_space_dao = KnowledgeSpaceDao()
|
||||
knowledge_document_dao = KnowledgeDocumentDao()
|
||||
@@ -82,6 +80,14 @@ class KnowledgeService:
|
||||
rag_config = CFG.SYSTEM_APP.config.configs.get("app_config").rag
|
||||
return rag_config
|
||||
|
||||
@property
|
||||
def storage_manager(self):
|
||||
return StorageManager.get_instance(CFG.SYSTEM_APP)
|
||||
|
||||
@property
|
||||
def system_app(self):
|
||||
return CFG.SYSTEM_APP
|
||||
|
||||
def create_knowledge_space(self, request: KnowledgeSpaceRequest):
|
||||
"""create knowledge space
|
||||
Args:
|
||||
@@ -91,7 +97,7 @@ class KnowledgeService:
|
||||
name=request.name,
|
||||
)
|
||||
if request.vector_type == "VectorStore":
|
||||
request.vector_type = self.rag_config.storage.vector.get("type")
|
||||
request.vector_type = self.rag_config.storage.vector.get_type_value()
|
||||
if request.vector_type == "KnowledgeGraph":
|
||||
knowledge_space_name_pattern = r"^[a-zA-Z0-9\u4e00-\u9fa5]+$"
|
||||
if not re.match(knowledge_space_name_pattern, request.name):
|
||||
@@ -412,28 +418,15 @@ class KnowledgeService:
|
||||
if len(spaces) != 1:
|
||||
raise Exception(f"invalid space name:{space_name}")
|
||||
space = spaces[0]
|
||||
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embedding_fn = embedding_factory.create()
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
llm_client=self.llm_client,
|
||||
model_name=None,
|
||||
)
|
||||
if space.domain_type == DOMAIN_TYPE_FINANCIAL_REPORT:
|
||||
conn_manager = CFG.local_db_manager
|
||||
conn_manager.delete_db(f"{space.name}_fin_report")
|
||||
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=CFG.SYSTEM_APP,
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
index_name=space_name, storage_type=space.vector_type
|
||||
)
|
||||
# delete vectors
|
||||
vector_store_connector.delete_vector_name(space.name)
|
||||
storage_connector.delete_vector_name(space.name)
|
||||
document_query = KnowledgeDocumentEntity(space=space.name)
|
||||
# delete chunks
|
||||
documents = knowledge_document_dao.get_documents(document_query)
|
||||
@@ -462,23 +455,11 @@ class KnowledgeService:
|
||||
|
||||
vector_ids = documents[0].vector_ids
|
||||
if vector_ids is not None:
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embedding_fn = embedding_factory.create()
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
llm_client=self.llm_client,
|
||||
model_name=None,
|
||||
)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=CFG.SYSTEM_APP,
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
index_name=space_name, storage_type=space.vector_type
|
||||
)
|
||||
# delete vector by ids
|
||||
vector_store_connector.delete_by_ids(vector_ids)
|
||||
storage_connector.delete_by_ids(vector_ids)
|
||||
# delete chunks
|
||||
document_chunk_dao.raw_delete(documents[0].id)
|
||||
# delete document
|
||||
@@ -628,29 +609,12 @@ class KnowledgeService:
|
||||
return chat
|
||||
|
||||
def query_graph(self, space_name, limit):
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embedding_fn = embedding_factory.create()
|
||||
spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
|
||||
if len(spaces) != 1:
|
||||
raise Exception(f"invalid space name:{space_name}")
|
||||
space = spaces[0]
|
||||
print(CFG.LLM_MODEL)
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
|
||||
llm_client=self.llm_client,
|
||||
model_name=None,
|
||||
)
|
||||
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=CFG.SYSTEM_APP,
|
||||
)
|
||||
graph = vector_store_connector.client.query_graph(limit=limit)
|
||||
# space = spaces[0]
|
||||
graph_store = self.storage_manager.create_kg_store(index_name=space_name)
|
||||
graph = graph_store.query_graph(limit=limit)
|
||||
res = {"nodes": [], "edges": []}
|
||||
for node in graph.vertices():
|
||||
res["nodes"].append(
|
||||
|
@@ -234,7 +234,7 @@ class ChatKnowledge(BaseChat):
|
||||
from dbgpt_ext.storage import __knowledge_graph__ as graph_storages
|
||||
|
||||
if spaces[0].vector_type in graph_storages:
|
||||
return self.rag_config.graph_search_top_k
|
||||
return self.rag_config.kg_chunk_search_top_k
|
||||
|
||||
return self.rag_config.similarity_top_k
|
||||
|
||||
|
Reference in New Issue
Block a user