mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-20 00:54:43 +00:00
refactor: rag storage refactor (#2434)
This commit is contained in:
@@ -8,11 +8,12 @@ from dbgpt.component import SystemApp
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
from dbgpt_ext.rag import ChunkParameters
|
||||
from dbgpt_ext.rag.summary.gdbms_db_summary import GdbmsSummary
|
||||
from dbgpt_ext.rag.summary.rdbms_db_summary import RdbmsSummary
|
||||
from dbgpt_serve.datasource.manages import ConnectorManager
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,8 +69,8 @@ class DBSummaryClient:
|
||||
)
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=topk,
|
||||
table_vector_store_connector=table_vector_connector.index_client,
|
||||
field_vector_store_connector=field_vector_connector.index_client,
|
||||
table_vector_store_connector=table_vector_connector,
|
||||
field_vector_store_connector=field_vector_connector,
|
||||
separator="--table-field-separator--",
|
||||
)
|
||||
|
||||
@@ -116,8 +117,8 @@ class DBSummaryClient:
|
||||
)
|
||||
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=db_summary_client.db,
|
||||
table_vector_store_connector=table_vector_connector.index_client,
|
||||
field_vector_store_connector=field_vector_connector.index_client,
|
||||
table_vector_store_connector=table_vector_connector,
|
||||
field_vector_store_connector=field_vector_connector,
|
||||
chunk_parameters=chunk_parameters,
|
||||
max_seq_length=self.app_config.service.web.embedding_model_max_seq_len,
|
||||
)
|
||||
@@ -157,23 +158,14 @@ class DBSummaryClient:
|
||||
|
||||
def _get_vector_connector_by_db(
|
||||
self, dbname
|
||||
) -> Tuple[VectorStoreConnector, VectorStoreConnector]:
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
) -> Tuple[VectorStoreBase, VectorStoreBase]:
|
||||
vector_store_name = dbname + "_profile"
|
||||
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
table_vector_connector = VectorStoreConnector.from_default(
|
||||
self.storage_config.vector.get("type"),
|
||||
self.embeddings,
|
||||
vector_store_config=table_vector_store_config,
|
||||
system_app=self.system_app,
|
||||
storage_manager = StorageManager.get_instance(self.system_app)
|
||||
table_vector_store = storage_manager.create_vector_store(
|
||||
index_name=vector_store_name
|
||||
)
|
||||
field_vector_store_name = dbname + "_profile_field"
|
||||
field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
|
||||
field_vector_connector = VectorStoreConnector.from_default(
|
||||
self.storage_config.vector.get("type"),
|
||||
self.embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
system_app=self.system_app,
|
||||
field_vector_store = storage_manager.create_vector_store(
|
||||
index_name=field_vector_store_name
|
||||
)
|
||||
return table_vector_connector, field_vector_connector
|
||||
return table_vector_store, field_vector_store
|
||||
|
@@ -10,7 +10,6 @@ from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.datasource.parameter import BaseDatasourceParameters
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
from dbgpt_ext.datasource.schema import DBType
|
||||
from dbgpt_serve.core import BaseService, ResourceTypes
|
||||
@@ -19,8 +18,8 @@ from dbgpt_serve.datasource.manages.connect_config_db import (
|
||||
ConnectConfigDao,
|
||||
ConnectConfigEntity,
|
||||
)
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
|
||||
from ...rag.storage_manager import StorageManager
|
||||
from ..api.schemas import (
|
||||
DatasourceCreateRequest,
|
||||
DatasourceQueryResponse,
|
||||
@@ -93,6 +92,12 @@ class Service(
|
||||
raise ValueError("SYSTEM_APP is not set")
|
||||
return ConnectorManager.get_instance(self._system_app)
|
||||
|
||||
@property
def storage_manager(self) -> StorageManager:
    """Return the storage manager looked up for this service's system app.

    Raises:
        ValueError: If no system app has been set on the service.
    """
    system_app = self._system_app
    if system_app:
        return StorageManager.get_instance(system_app)
    raise ValueError("SYSTEM_APP is not set")
|
||||
|
||||
def create(
|
||||
self, request: Union[DatasourceCreateRequest, DatasourceServeRequest]
|
||||
) -> DatasourceQueryResponse:
|
||||
@@ -229,13 +234,10 @@ class Service(
|
||||
"""
|
||||
db_config = self._dao.get_one({"id": datasource_id})
|
||||
vector_name = db_config.db_name + "_profile"
|
||||
vector_store_config = VectorStoreConfig(name=vector_name)
|
||||
_vector_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_config=vector_store_config,
|
||||
system_app=self._system_app,
|
||||
vector_connector = self.storage_manager.create_vector_store(
|
||||
index_name=vector_name
|
||||
)
|
||||
_vector_connector.delete_vector_name(vector_name)
|
||||
vector_connector.delete_vector_name(vector_name)
|
||||
if db_config:
|
||||
self._dao.delete({"id": datasource_id})
|
||||
return db_config
|
||||
@@ -300,13 +302,10 @@ class Service(
|
||||
"""
|
||||
db_config = self._dao.get_one({"id": datasource_id})
|
||||
vector_name = db_config.db_name + "_profile"
|
||||
vector_store_config = VectorStoreConfig(name=vector_name)
|
||||
_vector_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_config=vector_store_config,
|
||||
system_app=self._system_app,
|
||||
vector_connector = self.storage_manager.create_vector_store(
|
||||
index_name=vector_name
|
||||
)
|
||||
_vector_connector.delete_vector_name(vector_name)
|
||||
vector_connector.delete_vector_name(vector_name)
|
||||
self._db_summary_client.db_summary_embedding(
|
||||
db_config.db_name, db_config.db_type
|
||||
)
|
||||
|
@@ -15,15 +15,16 @@ from dbgpt.rag.evaluation import RetrieverEvaluator
|
||||
from dbgpt.rag.evaluation.answer import AnswerRelevancyMetric
|
||||
from dbgpt.rag.evaluation.retriever import RetrieverSimilarityMetric
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt_serve.rag.operators.knowledge_space import SpaceRetrieverOperator
|
||||
|
||||
from ...agent.agents.controller import multi_agents
|
||||
from ...agent.evaluation.evaluation import AgentEvaluator, AgentOutputOperator
|
||||
from ...core import BaseService
|
||||
from ...prompt.service.service import Service as PromptService
|
||||
from ...rag.connector import VectorStoreConnector
|
||||
|
||||
# from ...rag.connector import VectorStoreConnector
|
||||
from ...rag.service.service import Service as RagService
|
||||
from ...rag.storage_manager import StorageManager
|
||||
from ..api.schemas import EvaluateServeRequest, EvaluateServeResponse, EvaluationScene
|
||||
from ..config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
@@ -64,6 +65,10 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon
|
||||
"""
|
||||
self._system_app = system_app
|
||||
|
||||
@property
def storage_manager(self) -> "StorageManager":
    """Return the ``StorageManager`` obtained for this service's system app.

    Raises:
        ValueError: If no system app has been set on the service. Failing
            here gives a clear message instead of an error raised from inside
            ``StorageManager.get_instance``.
    """
    if not self._system_app:
        raise ValueError("SYSTEM_APP is not set")
    return StorageManager.get_instance(self._system_app)
|
||||
|
||||
@property
|
||||
def dao(self) -> BaseDao[ServeEntity, EvaluateServeRequest, EvaluateServeResponse]:
|
||||
"""Returns the internal DAO."""
|
||||
@@ -104,17 +109,13 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon
|
||||
)
|
||||
embeddings = embedding_factory.create()
|
||||
|
||||
config = VectorStoreConfig(
|
||||
name=scene_value,
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
space = self.rag_service.get({"space_id": str(scene_value)})
|
||||
if not space:
|
||||
raise ValueError(f"Space {scene_value} not found")
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=self._system_app,
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
index_name=space.name,
|
||||
storage_type=space.vector_type,
|
||||
llm_model=context.get("llm_model"),
|
||||
)
|
||||
evaluator = RetrieverEvaluator(
|
||||
operator_cls=SpaceRetrieverOperator,
|
||||
@@ -122,7 +123,7 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon
|
||||
operator_kwargs={
|
||||
"space_id": str(scene_value),
|
||||
"top_k": self._serve_config.similarity_top_k,
|
||||
"vector_store_connector": vector_store_connector,
|
||||
"vector_store_connector": storage_connector,
|
||||
},
|
||||
)
|
||||
metrics = []
|
||||
|
@@ -11,6 +11,7 @@ from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.storage.base import IndexStoreConfig
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt_ext.storage import __document_store__ as supported_full_tet_list
|
||||
from dbgpt_ext.storage import __knowledge_graph__ as supported_kg_store_list
|
||||
from dbgpt_ext.storage import __vector_store__ as supported_vector_store_list
|
||||
|
||||
@@ -70,19 +71,14 @@ class VectorStoreConnector:
|
||||
self._vector_store_type = vector_store_type
|
||||
self._embeddings = vector_store_config.embedding_fn
|
||||
|
||||
config_dict = {}
|
||||
storage_config = self.app_config.rag.storage
|
||||
if vector_store_type in supported_vector_store_list:
|
||||
config_dict = storage_config.vector
|
||||
elif vector_store_type in supported_kg_store_list:
|
||||
config_dict = storage_config.graph
|
||||
config_dict = self._adapt_storage_config(vector_store_type).to_dict()
|
||||
for key in vector_store_config.to_dict().keys():
|
||||
value = getattr(vector_store_config, key)
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
for key, value in vector_store_config.model_extra.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
# for key, value in vector_store_config.model_extra.items():
|
||||
# if value is not None:
|
||||
# config_dict[key] = value
|
||||
config = self.config_class(**config_dict)
|
||||
try:
|
||||
if vector_store_type in pools and config.name in pools[vector_store_type]:
|
||||
@@ -97,10 +93,10 @@ class VectorStoreConnector:
|
||||
def _rewrite_index_store_type(self, index_store_type):
|
||||
# Rewrite Knowledge Graph Type
|
||||
if self.app_config.rag.storage.graph:
|
||||
graph_dict = self.app_config.rag.storage.graph
|
||||
graph_config = self.app_config.rag.storage.graph
|
||||
if (
|
||||
isinstance(graph_dict, dict)
|
||||
and graph_dict.get("enable_summary", "false").lower() == "true"
|
||||
hasattr(graph_config, "enable_summary")
|
||||
and graph_config.enable_summary.lower() == "true"
|
||||
):
|
||||
if index_store_type == "KnowledgeGraph":
|
||||
return "CommunitySummaryKnowledgeGraph"
|
||||
@@ -286,3 +282,14 @@ class VectorStoreConnector:
|
||||
for cls_name in rag_storages:
|
||||
store_cls, config_cls = _select_rag_storage(cls_name)
|
||||
connector[cls_name] = (store_cls, config_cls)
|
||||
|
||||
def _adapt_storage_config(self, vector_store_type):
|
||||
"""Adapt storage config."""
|
||||
storage_config = self.app_config.rag.storage
|
||||
if vector_store_type in supported_vector_store_list:
|
||||
return storage_config.vector
|
||||
elif vector_store_type in supported_kg_store_list:
|
||||
return storage_config.graph
|
||||
elif vector_store_type in supported_full_tet_list:
|
||||
return storage_config.full_text
|
||||
raise ValueError(f"storage type {vector_store_type} not supported")
|
||||
|
@@ -100,6 +100,7 @@ class SpaceRetrieverOperator(RetrieverOperator[IN, OUT]):
|
||||
space_retriever = KnowledgeSpaceRetriever(
|
||||
space_id=self._space_id,
|
||||
top_k=self._top_k,
|
||||
system_app=self._service.system_app,
|
||||
)
|
||||
if isinstance(query, str):
|
||||
candidates = space_retriever.retrieve_with_scores(query, self._recall_score)
|
||||
|
@@ -2,17 +2,15 @@ from typing import List, Optional
|
||||
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt_serve.rag.models.models import KnowledgeSpaceDao
|
||||
from dbgpt_serve.rag.retriever.qa_retriever import QARetriever
|
||||
from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain
|
||||
from dbgpt_serve.rag.storage_manager import StorageManager
|
||||
|
||||
|
||||
class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
@@ -49,7 +47,6 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embedding_fn = embedding_factory.create()
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
space_dao = KnowledgeSpaceDao()
|
||||
space = space_dao.get_one({"id": space_id})
|
||||
@@ -57,21 +54,10 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
space = space_dao.get_one({"name": space_id})
|
||||
if space is None:
|
||||
raise ValueError(f"Knowledge space {space_id} not found")
|
||||
worker_manager = self._system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
llm_client = DefaultLLMClient(worker_manager=worker_manager)
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
llm_client=llm_client,
|
||||
llm_model=self._llm_model,
|
||||
)
|
||||
|
||||
self._vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=self._system_app,
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
space.name,
|
||||
space.vector_type,
|
||||
self._llm_model,
|
||||
)
|
||||
self._executor = self._system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
@@ -86,7 +72,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
system_app=system_app,
|
||||
),
|
||||
EmbeddingRetriever(
|
||||
index_store=self._vector_store_connector.index_client,
|
||||
index_store=storage_connector,
|
||||
top_k=self._top_k,
|
||||
query_rewrite=self._query_rewrite,
|
||||
rerank=self._rerank,
|
||||
@@ -95,6 +81,10 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
executor=self._executor,
|
||||
)
|
||||
|
||||
@property
def storage_manager(self) -> "StorageManager":
    """Return the ``StorageManager`` obtained for this retriever's system app.

    Raises:
        ValueError: If no system app has been set on the retriever. Failing
            here gives a clear message instead of an error raised from inside
            ``StorageManager.get_instance``.
    """
    if not self._system_app:
        raise ValueError("SYSTEM_APP is not set")
    return StorageManager.get_instance(self._system_app)
|
||||
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
|
@@ -18,13 +18,11 @@ from dbgpt.configs.model_config import (
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding import EmbeddingFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeType
|
||||
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
from dbgpt.util.string_utils import remove_trailing_punctuation
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
@@ -33,7 +31,6 @@ from dbgpt_ext.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt_ext.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt_ext.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt_serve.core import BaseService
|
||||
from dbgpt_serve.rag.connector import VectorStoreConnector
|
||||
|
||||
from ..api.schemas import (
|
||||
ChunkServeRequest,
|
||||
@@ -52,6 +49,7 @@ from ..models.document_db import (
|
||||
)
|
||||
from ..models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||
from ..retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||
from ..storage_manager import StorageManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,6 +94,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
self._chunk_dao = self._chunk_dao or DocumentChunkDao()
|
||||
self._system_app = system_app
|
||||
|
||||
@property
def storage_manager(self) -> "StorageManager":
    """Return the ``StorageManager`` obtained for this service's system app.

    Raises:
        ValueError: If no system app has been set on the service. Failing
            here gives a clear message instead of an error raised from inside
            ``StorageManager.get_instance``.
    """
    if not self._system_app:
        raise ValueError("SYSTEM_APP is not set")
    return StorageManager.get_instance(self._system_app)
|
||||
|
||||
@property
|
||||
def dao(
|
||||
self,
|
||||
@@ -286,14 +288,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
space = self.get(query_request)
|
||||
if space is None:
|
||||
raise HTTPException(status_code=400, detail=f"Space {space_id} not found")
|
||||
config = VectorStoreConfig(
|
||||
name=space.name, llm_client=self.llm_client, model_name=None
|
||||
)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=self._system_app,
|
||||
)
|
||||
vector_store_connector = self.create_vector_store(space.name)
|
||||
# delete vectors
|
||||
vector_store_connector.delete_vector_name(space.name)
|
||||
document_query = KnowledgeDocumentEntity(space=space.name)
|
||||
@@ -360,14 +355,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
|
||||
vector_ids = docuemnt.vector_ids
|
||||
if vector_ids is not None:
|
||||
config = VectorStoreConfig(
|
||||
name=space.name, llm_client=self.llm_client, model_name=None
|
||||
)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=self._system_app,
|
||||
)
|
||||
vector_store_connector = self.create_vector_store(space.name)
|
||||
# delete vector by ids
|
||||
vector_store_connector.delete_by_ids(vector_ids)
|
||||
# delete chunks
|
||||
@@ -498,25 +486,9 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
chunk_parameters: ChunkParameters,
|
||||
) -> None:
|
||||
"""sync knowledge document chunk into vector store"""
|
||||
embedding_factory = self._system_app.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embedding_fn = embedding_factory.create()
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
space = self.get({"id": space_id})
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
max_chunks_once_load=self._serve_config.max_chunks_once_load,
|
||||
max_threads=self._serve_config.max_threads,
|
||||
llm_client=self.llm_client,
|
||||
model_name=None,
|
||||
)
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
system_app=self._system_app,
|
||||
storage_connector = self.storage_manager.get_storage_connector(
|
||||
space.name, space.vector_type
|
||||
)
|
||||
knowledge = None
|
||||
if not space.domain_type or (
|
||||
@@ -531,17 +503,17 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
doc.gmt_modified = datetime.now()
|
||||
self._document_dao.update_knowledge_document(doc)
|
||||
asyncio.create_task(
|
||||
self.async_doc_embedding(
|
||||
knowledge, chunk_parameters, vector_store_connector, doc, space
|
||||
self.async_doc_process(
|
||||
knowledge, chunk_parameters, storage_connector, doc, space
|
||||
)
|
||||
)
|
||||
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
|
||||
|
||||
@trace("async_doc_embedding")
|
||||
async def async_doc_embedding(
|
||||
self, knowledge, chunk_parameters, vector_store_connector, doc, space
|
||||
@trace("async_doc_process")
|
||||
async def async_doc_process(
|
||||
self, knowledge, chunk_parameters, storage_connector, doc, space
|
||||
):
|
||||
"""async document embedding into vector db
|
||||
"""async document process into storage
|
||||
Args:
|
||||
- knowledge: Knowledge
|
||||
- chunk_parameters: ChunkParameters
|
||||
@@ -572,13 +544,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
|
||||
doc.chunk_size = len(chunk_docs)
|
||||
vector_ids = [chunk.chunk_id for chunk in chunk_docs]
|
||||
else:
|
||||
max_chunks_once_load = (
|
||||
vector_store_connector._index_store_config.max_chunks_once_load
|
||||
)
|
||||
max_threads = vector_store_connector._index_store_config.max_threads
|
||||
max_chunks_once_load = self.config.max_chunks_once_load
|
||||
max_threads = self.config.max_threads
|
||||
assembler = await EmbeddingAssembler.aload_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
index_store=vector_store_connector.index_client,
|
||||
index_store=storage_connector,
|
||||
chunk_parameters=chunk_parameters,
|
||||
)
|
||||
|
||||
|
152
packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py
Normal file
152
packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""RAG STORAGE MANAGER manager."""
|
||||
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from dbgpt import BaseComponent
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding import EmbeddingFactory
|
||||
from dbgpt.storage.base import IndexStoreBase
|
||||
from dbgpt.storage.full_text.base import FullTextStoreBase
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt_ext.storage.full_text.elasticsearch import ElasticDocumentStore
|
||||
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import BuiltinKnowledgeGraph
|
||||
|
||||
|
||||
class StorageManager(BaseComponent):
    """Build RAG index stores from the application's ``rag.storage`` config.

    The component is named ``ComponentType.RAG_STORAGE_MANAGER``. It creates
    vector stores, knowledge graph stores and full-text stores, reading their
    settings from the ``app_config`` entry of the system app's configs.
    """

    name = ComponentType.RAG_STORAGE_MANAGER

    def __init__(self, system_app: SystemApp):
        """Create a new StorageManager."""
        # Set before calling the base initializer so the attribute exists even
        # if the base class touches the component during construction.
        self.system_app = system_app
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Init component."""
        self.system_app = system_app

    def storage_config(self):
        """Return the ``rag.storage`` section of the ``app_config`` entry."""
        app_config = self.system_app.config.configs.get("app_config")
        return app_config.rag.storage

    def get_storage_connector(
        self, index_name: str, storage_type: str, llm_model: Optional[str] = None
    ) -> IndexStoreBase:
        """Create the index store matching ``storage_type``.

        Args:
            index_name: Name passed to the created store.
            storage_type: A supported vector store type, ``"KnowledgeGraph"``
                or ``"FullText"``. For vector types the value is only used to
                select the branch; the store itself is built from the
                configured ``storage.vector`` section.
            llm_model: Model name forwarded to the knowledge graph store;
                ignored for the other storage kinds.

        Returns:
            The newly created store.

        Raises:
            ValueError: If the storage type is unknown, or if the graph /
                full-text section required for it is not configured.
        """
        supported_vector_types = self.get_vector_supported_types
        storage_config = self.storage_config()
        if storage_type in supported_vector_types:
            return self.create_vector_store(index_name)
        if storage_type == "KnowledgeGraph":
            if not storage_config.graph:
                raise ValueError(
                    "Graph storage is not configured. Please check your config, "
                    "reference configs/dbgpt-graphrag.toml"
                )
            return self.create_kg_store(index_name, llm_model)
        if storage_type == "FullText":
            if not storage_config.full_text:
                raise ValueError(
                    "FullText storage is not configured. Please check your "
                    "config, reference configs/dbgpt-bm25-rag.toml"
                )
            return self.create_full_text_store(index_name)
        raise ValueError(f"Does not support storage type {storage_type}")

    def create_vector_store(self, index_name) -> VectorStoreBase:
        """Create a vector store named ``index_name``.

        The store is produced by the configured ``storage.vector`` config via
        its ``create_store`` method, using an embedding function obtained from
        the ``embedding_factory`` component.
        """
        storage_config = self.storage_config()
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        embedding_fn = embedding_factory.create()
        vector_store_config: VectorStoreConfig = storage_config.vector
        return vector_store_config.create_store(
            name=index_name, embedding_fn=embedding_fn
        )

    def create_kg_store(
        self, index_name, llm_model: Optional[str] = None
    ) -> BuiltinKnowledgeGraph:
        """Create a knowledge graph store named ``index_name``.

        When the graph config has a truthy ``enable_summary`` attribute a
        ``CommunitySummaryKnowledgeGraph`` is returned, otherwise a plain
        ``BuiltinKnowledgeGraph``.

        Args:
            index_name: Name passed to the created graph store.
            llm_model: Model name written onto the graph config before the
                store is built.
        """
        app_config = self.system_app.config.configs.get("app_config")
        rag_config = app_config.rag
        storage_config = rag_config.storage
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        llm_client = DefaultLLMClient(worker_manager=worker_manager)
        if storage_config.graph:
            graph_config = storage_config.graph
            # NOTE(review): this assigns on the config object read from the app
            # config (presumably shared between calls), so a later call with
            # llm_model=None resets the value — confirm this is intended.
            graph_config.llm_model = llm_model
            if hasattr(graph_config, "enable_summary") and graph_config.enable_summary:
                # Imported lazily so the community summary implementation is
                # only loaded when summaries are enabled.
                from dbgpt_ext.storage.knowledge_graph.community_summary import (
                    CommunitySummaryKnowledgeGraph,
                )

                # Only the community summary graph consumes embeddings, so the
                # embedding function is created on this path alone.
                embedding_factory = self.system_app.get_component(
                    "embedding_factory", EmbeddingFactory
                )
                embedding_fn = embedding_factory.create()

                return CommunitySummaryKnowledgeGraph(
                    config=storage_config.graph,
                    name=index_name,
                    llm_client=llm_client,
                    vector_store_config=storage_config.vector,
                    kg_extract_top_k=rag_config.kg_extract_top_k,
                    kg_extract_score_threshold=rag_config.kg_extract_score_threshold,
                    kg_community_top_k=rag_config.kg_community_top_k,
                    kg_community_score_threshold=rag_config.kg_community_score_threshold,
                    kg_triplet_graph_enabled=rag_config.kg_triplet_graph_enabled,
                    kg_document_graph_enabled=rag_config.kg_document_graph_enabled,
                    kg_chunk_search_top_k=rag_config.kg_chunk_search_top_k,
                    kg_extraction_batch_size=rag_config.kg_extraction_batch_size,
                    kg_community_summary_batch_size=rag_config.kg_community_summary_batch_size,
                    kg_embedding_batch_size=rag_config.kg_embedding_batch_size,
                    kg_similarity_top_k=rag_config.kg_similarity_top_k,
                    kg_similarity_score_threshold=rag_config.kg_similarity_score_threshold,
                    kg_enable_text_search=rag_config.kg_enable_text_search,
                    kg_text2gql_model_enabled=rag_config.kg_text2gql_model_enabled,
                    kg_text2gql_model_name=rag_config.kg_text2gql_model_name,
                    embedding_fn=embedding_fn,
                    kg_max_chunks_once_load=rag_config.max_chunks_once_load,
                    kg_max_threads=rag_config.max_threads,
                )
        return BuiltinKnowledgeGraph(
            config=storage_config.graph,
            name=index_name,
            llm_client=llm_client,
        )

    def create_full_text_store(self, index_name) -> FullTextStoreBase:
        """Create an ``ElasticDocumentStore`` named ``index_name``.

        The configured ``storage.full_text`` section is used as the store
        config, with the BM25 ``k1`` and ``b`` values taken from the rag config.
        """
        app_config = self.system_app.config.configs.get("app_config")
        rag_config = app_config.rag
        storage_config = rag_config.storage
        return ElasticDocumentStore(
            es_config=storage_config.full_text,
            name=index_name,
            k1=rag_config.bm25_k1,
            b=rag_config.bm25_b,
        )

    @property
    def get_vector_supported_types(self) -> List[str]:
        """Return the ``__type__`` of every direct ``VectorStoreConfig`` subclass."""
        return [vector_cls.__type__ for vector_cls in _get_all_subclasses()]
|
||||
|
||||
|
||||
def _get_all_subclasses() -> List[Type[VectorStoreConfig]]:
    """Return the direct subclasses of ``VectorStoreConfig``.

    The lookup uses ``__subclasses__``, so only immediate children that have
    already been defined are reported; deeper descendants are not included.
    """
    direct_subclasses = VectorStoreConfig.__subclasses__()
    return direct_subclasses
|
Reference in New Issue
Block a user